feat: Added ndarray-safetensors
This commit is contained in:
44
src/cli.rs
44
src/cli.rs
@@ -13,6 +13,12 @@ pub enum SubCommand {
|
||||
Detect(Detect),
|
||||
#[clap(name = "list")]
|
||||
List(List),
|
||||
#[clap(name = "query")]
|
||||
Query(Query),
|
||||
#[clap(name = "similar")]
|
||||
Similar(Similar),
|
||||
#[clap(name = "stats")]
|
||||
Stats(Stats),
|
||||
#[clap(name = "completions")]
|
||||
Completions { shell: clap_complete::Shell },
|
||||
}
|
||||
@@ -58,12 +64,50 @@ pub struct Detect {
|
||||
pub nms_threshold: f32,
|
||||
#[clap(short, long, default_value_t = 8)]
|
||||
pub batch_size: usize,
|
||||
#[clap(short = 'd', long)]
|
||||
pub database: Option<PathBuf>,
|
||||
#[clap(long, default_value = "facenet")]
|
||||
pub model_name: String,
|
||||
#[clap(long)]
|
||||
pub save_to_db: bool,
|
||||
pub image: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct List {}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct Query {
|
||||
#[clap(short = 'd', long, default_value = "face_detections.db")]
|
||||
pub database: PathBuf,
|
||||
#[clap(short, long)]
|
||||
pub image_id: Option<i64>,
|
||||
#[clap(short, long)]
|
||||
pub face_id: Option<i64>,
|
||||
#[clap(long)]
|
||||
pub show_embeddings: bool,
|
||||
#[clap(long)]
|
||||
pub show_landmarks: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct Similar {
|
||||
#[clap(short = 'd', long, default_value = "face_detections.db")]
|
||||
pub database: PathBuf,
|
||||
#[clap(short, long)]
|
||||
pub face_id: i64,
|
||||
#[clap(short, long, default_value_t = 0.7)]
|
||||
pub threshold: f32,
|
||||
#[clap(short, long, default_value_t = 10)]
|
||||
pub limit: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct Stats {
|
||||
#[clap(short = 'd', long, default_value = "face_detections.db")]
|
||||
pub database: PathBuf,
|
||||
}
|
||||
|
||||
impl Cli {
|
||||
pub fn completions(shell: clap_complete::Shell) {
|
||||
let mut command = <Cli as clap::CommandFactory>::command();
|
||||
|
||||
548
src/database.rs
Normal file
548
src/database.rs
Normal file
@@ -0,0 +1,548 @@
|
||||
use crate::errors::{Error, Result};
|
||||
use crate::facedet::{FaceDetectionOutput, FaceLandmarks};
|
||||
use bounding_box::Aabb2;
|
||||
use error_stack::ResultExt;
|
||||
use rusqlite::{Connection, OptionalExtension, params};
|
||||
use std::path::Path;
|
||||
|
||||
/// Database connection and operations for face detection results
|
||||
pub struct FaceDatabase {
|
||||
conn: Connection,
|
||||
}
|
||||
|
||||
/// Represents a stored image record
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ImageRecord {
|
||||
pub id: i64,
|
||||
pub file_path: String,
|
||||
pub width: u32,
|
||||
pub height: u32,
|
||||
pub created_at: String,
|
||||
}
|
||||
|
||||
/// Represents a stored face detection record
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FaceRecord {
|
||||
pub id: i64,
|
||||
pub image_id: i64,
|
||||
pub bbox_x1: f32,
|
||||
pub bbox_y1: f32,
|
||||
pub bbox_x2: f32,
|
||||
pub bbox_y2: f32,
|
||||
pub confidence: f32,
|
||||
pub created_at: String,
|
||||
}
|
||||
|
||||
/// Represents stored face landmarks
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LandmarkRecord {
|
||||
pub id: i64,
|
||||
pub face_id: i64,
|
||||
pub left_eye_x: f32,
|
||||
pub left_eye_y: f32,
|
||||
pub right_eye_x: f32,
|
||||
pub right_eye_y: f32,
|
||||
pub nose_x: f32,
|
||||
pub nose_y: f32,
|
||||
pub left_mouth_x: f32,
|
||||
pub left_mouth_y: f32,
|
||||
pub right_mouth_x: f32,
|
||||
pub right_mouth_y: f32,
|
||||
}
|
||||
|
||||
/// Represents a stored face embedding
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EmbeddingRecord {
|
||||
pub id: i64,
|
||||
pub face_id: i64,
|
||||
pub embedding: Vec<f32>,
|
||||
pub model_name: String,
|
||||
pub created_at: String,
|
||||
}
|
||||
|
||||
impl FaceDatabase {
|
||||
/// Create a new database connection and initialize tables
|
||||
pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Self> {
|
||||
let conn = Connection::open(db_path).change_context(Error)?;
|
||||
let db = Self { conn };
|
||||
db.create_tables()?;
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
/// Create an in-memory database for testing
|
||||
pub fn in_memory() -> Result<Self> {
|
||||
let conn = Connection::open_in_memory().change_context(Error)?;
|
||||
let db = Self { conn };
|
||||
db.create_tables()?;
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
/// Create all necessary database tables
|
||||
fn create_tables(&self) -> Result<()> {
|
||||
// Images table
|
||||
self.conn
|
||||
.execute(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS images (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
file_path TEXT NOT NULL UNIQUE,
|
||||
width INTEGER NOT NULL,
|
||||
height INTEGER NOT NULL,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"#,
|
||||
[],
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
// Faces table
|
||||
self.conn
|
||||
.execute(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS faces (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
image_id INTEGER NOT NULL,
|
||||
bbox_x1 REAL NOT NULL,
|
||||
bbox_y1 REAL NOT NULL,
|
||||
bbox_x2 REAL NOT NULL,
|
||||
bbox_y2 REAL NOT NULL,
|
||||
confidence REAL NOT NULL,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE
|
||||
)
|
||||
"#,
|
||||
[],
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
// Landmarks table
|
||||
self.conn
|
||||
.execute(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS landmarks (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
face_id INTEGER NOT NULL,
|
||||
left_eye_x REAL NOT NULL,
|
||||
left_eye_y REAL NOT NULL,
|
||||
right_eye_x REAL NOT NULL,
|
||||
right_eye_y REAL NOT NULL,
|
||||
nose_x REAL NOT NULL,
|
||||
nose_y REAL NOT NULL,
|
||||
left_mouth_x REAL NOT NULL,
|
||||
left_mouth_y REAL NOT NULL,
|
||||
right_mouth_x REAL NOT NULL,
|
||||
right_mouth_y REAL NOT NULL,
|
||||
FOREIGN KEY (face_id) REFERENCES faces (id) ON DELETE CASCADE
|
||||
)
|
||||
"#,
|
||||
[],
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
// Embeddings table
|
||||
self.conn
|
||||
.execute(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS embeddings (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
face_id INTEGER NOT NULL,
|
||||
embedding BLOB NOT NULL,
|
||||
model_name TEXT NOT NULL,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (face_id) REFERENCES faces (id) ON DELETE CASCADE
|
||||
)
|
||||
"#,
|
||||
[],
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
// Create indexes for better performance
|
||||
self.conn
|
||||
.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_faces_image_id ON faces (image_id)",
|
||||
[],
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
self.conn
|
||||
.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_landmarks_face_id ON landmarks (face_id)",
|
||||
[],
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
self.conn
|
||||
.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_embeddings_face_id ON embeddings (face_id)",
|
||||
[],
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Store image metadata and return the image ID
|
||||
pub fn store_image(&self, file_path: &str, width: u32, height: u32) -> Result<i64> {
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare("INSERT OR REPLACE INTO images (file_path, width, height) VALUES (?1, ?2, ?3)")
|
||||
.change_context(Error)?;
|
||||
|
||||
stmt.execute(params![file_path, width, height])
|
||||
.change_context(Error)?;
|
||||
|
||||
Ok(self.conn.last_insert_rowid())
|
||||
}
|
||||
|
||||
/// Store face detection results
|
||||
pub fn store_face_detections(
|
||||
&self,
|
||||
image_id: i64,
|
||||
detection_output: &FaceDetectionOutput,
|
||||
) -> Result<Vec<i64>> {
|
||||
let mut face_ids = Vec::new();
|
||||
|
||||
for (i, bbox) in detection_output.bbox.iter().enumerate() {
|
||||
let confidence = detection_output.confidence.get(i).copied().unwrap_or(0.0);
|
||||
|
||||
let face_id = self.store_face(image_id, bbox, confidence)?;
|
||||
face_ids.push(face_id);
|
||||
|
||||
// Store landmarks if available
|
||||
if let Some(landmarks) = detection_output.landmark.get(i) {
|
||||
self.store_landmarks(face_id, landmarks)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(face_ids)
|
||||
}
|
||||
|
||||
/// Store a single face detection
|
||||
pub fn store_face(&self, image_id: i64, bbox: &Aabb2<usize>, confidence: f32) -> Result<i64> {
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare(
|
||||
r#"
|
||||
INSERT INTO faces (image_id, bbox_x1, bbox_y1, bbox_x2, bbox_y2, confidence)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6)
|
||||
"#,
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
stmt.execute(params![
|
||||
image_id,
|
||||
bbox.x1() as f32,
|
||||
bbox.y1() as f32,
|
||||
bbox.x2() as f32,
|
||||
bbox.y2() as f32,
|
||||
confidence
|
||||
])
|
||||
.change_context(Error)?;
|
||||
|
||||
Ok(self.conn.last_insert_rowid())
|
||||
}
|
||||
|
||||
/// Store face landmarks
|
||||
pub fn store_landmarks(&self, face_id: i64, landmarks: &FaceLandmarks) -> Result<i64> {
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare(
|
||||
r#"
|
||||
INSERT INTO landmarks
|
||||
(face_id, left_eye_x, left_eye_y, right_eye_x, right_eye_y,
|
||||
nose_x, nose_y, left_mouth_x, left_mouth_y, right_mouth_x, right_mouth_y)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)
|
||||
"#,
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
stmt.execute(params![
|
||||
face_id,
|
||||
landmarks.left_eye.x,
|
||||
landmarks.left_eye.y,
|
||||
landmarks.right_eye.x,
|
||||
landmarks.right_eye.y,
|
||||
landmarks.nose.x,
|
||||
landmarks.nose.y,
|
||||
landmarks.left_mouth.x,
|
||||
landmarks.left_mouth.y,
|
||||
landmarks.right_mouth.x,
|
||||
landmarks.right_mouth.y,
|
||||
])
|
||||
.change_context(Error)?;
|
||||
|
||||
Ok(self.conn.last_insert_rowid())
|
||||
}
|
||||
|
||||
/// Store face embeddings
|
||||
pub fn store_embeddings(
|
||||
&self,
|
||||
face_ids: &[i64],
|
||||
embeddings: &[ndarray::Array2<f32>],
|
||||
model_name: &str,
|
||||
) -> Result<Vec<i64>> {
|
||||
let mut embedding_ids = Vec::new();
|
||||
|
||||
for (face_idx, embedding_batch) in embeddings.iter().enumerate() {
|
||||
for (batch_idx, embedding_row) in embedding_batch.rows().into_iter().enumerate() {
|
||||
let global_idx = face_idx * embedding_batch.nrows() + batch_idx;
|
||||
|
||||
if global_idx >= face_ids.len() {
|
||||
break;
|
||||
}
|
||||
|
||||
let face_id = face_ids[global_idx];
|
||||
let embedding_id =
|
||||
self.store_single_embedding(face_id, embedding_row, model_name)?;
|
||||
embedding_ids.push(embedding_id);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(embedding_ids)
|
||||
}
|
||||
|
||||
/// Store a single embedding
|
||||
pub fn store_single_embedding(
|
||||
&self,
|
||||
face_id: i64,
|
||||
embedding: ndarray::ArrayView1<f32>,
|
||||
model_name: &str,
|
||||
) -> Result<i64> {
|
||||
// Convert f32 slice to bytes
|
||||
// let embedding_bytes: Vec<u8> = embedding.iter().flat_map(|&f| f.to_le_bytes()).collect();
|
||||
let embedding_bytes = ndarray_safetensors::SafeArrays::new();
|
||||
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare("INSERT INTO embeddings (face_id, embedding, model_name) VALUES (?1, ?2, ?3)")
|
||||
.change_context(Error)?;
|
||||
|
||||
stmt.execute(params![face_id, embedding_bytes, model_name])
|
||||
.change_context(Error)?;
|
||||
|
||||
Ok(self.conn.last_insert_rowid())
|
||||
}
|
||||
|
||||
/// Get image by ID
|
||||
pub fn get_image(&self, image_id: i64) -> Result<Option<ImageRecord>> {
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare("SELECT id, file_path, width, height, created_at FROM images WHERE id = ?1")
|
||||
.change_context(Error)?;
|
||||
|
||||
let result = stmt
|
||||
.query_row(params![image_id], |row| {
|
||||
Ok(ImageRecord {
|
||||
id: row.get(0)?,
|
||||
file_path: row.get(1)?,
|
||||
width: row.get(2)?,
|
||||
height: row.get(3)?,
|
||||
created_at: row.get(4)?,
|
||||
})
|
||||
})
|
||||
.optional()
|
||||
.change_context(Error)?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Get all faces for an image
|
||||
pub fn get_faces_for_image(&self, image_id: i64) -> Result<Vec<FaceRecord>> {
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare(
|
||||
r#"
|
||||
SELECT id, image_id, bbox_x1, bbox_y1, bbox_x2, bbox_y2, confidence, created_at
|
||||
FROM faces WHERE image_id = ?1
|
||||
"#,
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
let face_iter = stmt
|
||||
.query_map(params![image_id], |row| {
|
||||
Ok(FaceRecord {
|
||||
id: row.get(0)?,
|
||||
image_id: row.get(1)?,
|
||||
bbox_x1: row.get(2)?,
|
||||
bbox_y1: row.get(3)?,
|
||||
bbox_x2: row.get(4)?,
|
||||
bbox_y2: row.get(5)?,
|
||||
confidence: row.get(6)?,
|
||||
created_at: row.get(7)?,
|
||||
})
|
||||
})
|
||||
.change_context(Error)?;
|
||||
|
||||
let mut faces = Vec::new();
|
||||
for face in face_iter {
|
||||
faces.push(face.change_context(Error)?);
|
||||
}
|
||||
|
||||
Ok(faces)
|
||||
}
|
||||
|
||||
/// Get landmarks for a face
|
||||
pub fn get_landmarks(&self, face_id: i64) -> Result<Option<LandmarkRecord>> {
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare(
|
||||
r#"
|
||||
SELECT id, face_id, left_eye_x, left_eye_y, right_eye_x, right_eye_y,
|
||||
nose_x, nose_y, left_mouth_x, left_mouth_y, right_mouth_x, right_mouth_y
|
||||
FROM landmarks WHERE face_id = ?1
|
||||
"#,
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
let result = stmt
|
||||
.query_row(params![face_id], |row| {
|
||||
Ok(LandmarkRecord {
|
||||
id: row.get(0)?,
|
||||
face_id: row.get(1)?,
|
||||
left_eye_x: row.get(2)?,
|
||||
left_eye_y: row.get(3)?,
|
||||
right_eye_x: row.get(4)?,
|
||||
right_eye_y: row.get(5)?,
|
||||
nose_x: row.get(6)?,
|
||||
nose_y: row.get(7)?,
|
||||
left_mouth_x: row.get(8)?,
|
||||
left_mouth_y: row.get(9)?,
|
||||
right_mouth_x: row.get(10)?,
|
||||
right_mouth_y: row.get(11)?,
|
||||
})
|
||||
})
|
||||
.optional()
|
||||
.change_context(Error)?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Get embeddings for a face
|
||||
pub fn get_embeddings(&self, face_id: i64) -> Result<Vec<EmbeddingRecord>> {
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare(
|
||||
"SELECT id, face_id, embedding, model_name, created_at FROM embeddings WHERE face_id = ?1",
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
let embedding_iter = stmt
|
||||
.query_map(params![face_id], |row| {
|
||||
let embedding_bytes: Vec<u8> = row.get(2)?;
|
||||
let embedding: Vec<f32> = embedding_bytes
|
||||
.chunks_exact(4)
|
||||
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
|
||||
.collect();
|
||||
|
||||
Ok(EmbeddingRecord {
|
||||
id: row.get(0)?,
|
||||
face_id: row.get(1)?,
|
||||
embedding,
|
||||
model_name: row.get(3)?,
|
||||
created_at: row.get(4)?,
|
||||
})
|
||||
})
|
||||
.change_context(Error)?;
|
||||
|
||||
let mut embeddings = Vec::new();
|
||||
for embedding in embedding_iter {
|
||||
embeddings.push(embedding.change_context(Error)?);
|
||||
}
|
||||
|
||||
Ok(embeddings)
|
||||
}
|
||||
|
||||
/// Find similar faces by embedding (using cosine similarity)
|
||||
pub fn find_similar_faces(
|
||||
&self,
|
||||
query_embedding: &[f32],
|
||||
threshold: f32,
|
||||
limit: usize,
|
||||
) -> Result<Vec<(i64, f32)>> {
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare("SELECT face_id, embedding FROM embeddings")
|
||||
.change_context(Error)?;
|
||||
|
||||
let embedding_iter = stmt
|
||||
.query_map([], |row| {
|
||||
let face_id: i64 = row.get(0)?;
|
||||
let embedding_bytes: Vec<u8> = row.get(1)?;
|
||||
let embedding: Vec<f32> = embedding_bytes
|
||||
.chunks_exact(4)
|
||||
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
|
||||
.collect();
|
||||
Ok((face_id, embedding))
|
||||
})
|
||||
.change_context(Error)?;
|
||||
|
||||
let mut similarities = Vec::new();
|
||||
for result in embedding_iter {
|
||||
let (face_id, embedding) = result.change_context(Error)?;
|
||||
let similarity = cosine_similarity(query_embedding, &embedding);
|
||||
if similarity >= threshold {
|
||||
similarities.push((face_id, similarity));
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by similarity (descending) and limit results
|
||||
similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
similarities.truncate(limit);
|
||||
|
||||
Ok(similarities)
|
||||
}
|
||||
|
||||
/// Get database statistics
|
||||
pub fn get_stats(&self) -> Result<(usize, usize, usize, usize)> {
|
||||
let images: usize = self
|
||||
.conn
|
||||
.query_row("SELECT COUNT(*) FROM images", [], |row| row.get(0))
|
||||
.change_context(Error)?;
|
||||
|
||||
let faces: usize = self
|
||||
.conn
|
||||
.query_row("SELECT COUNT(*) FROM faces", [], |row| row.get(0))
|
||||
.change_context(Error)?;
|
||||
|
||||
let landmarks: usize = self
|
||||
.conn
|
||||
.query_row("SELECT COUNT(*) FROM landmarks", [], |row| row.get(0))
|
||||
.change_context(Error)?;
|
||||
|
||||
let embeddings: usize = self
|
||||
.conn
|
||||
.query_row("SELECT COUNT(*) FROM embeddings", [], |row| row.get(0))
|
||||
.change_context(Error)?;
|
||||
|
||||
Ok((images, faces, landmarks, embeddings))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_database_creation() -> Result<()> {
|
||||
let db = FaceDatabase::in_memory()?;
|
||||
let (images, faces, landmarks, embeddings) = db.get_stats()?;
|
||||
assert_eq!(images, 0);
|
||||
assert_eq!(faces, 0);
|
||||
assert_eq!(landmarks, 0);
|
||||
assert_eq!(embeddings, 0);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_store_and_retrieve_image() -> Result<()> {
|
||||
let db = FaceDatabase::in_memory()?;
|
||||
let image_id = db.store_image("/path/to/image.jpg", 800, 600)?;
|
||||
|
||||
let image = db.get_image(image_id)?.unwrap();
|
||||
assert_eq!(image.file_path, "/path/to/image.jpg");
|
||||
assert_eq!(image.width, 800);
|
||||
assert_eq!(image.height, 600);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod database;
|
||||
pub mod errors;
|
||||
pub mod facedet;
|
||||
pub mod faceembed;
|
||||
|
||||
198
src/main.rs
198
src/main.rs
@@ -1,9 +1,10 @@
|
||||
mod cli;
|
||||
mod errors;
|
||||
use bounding_box::roi::MultiRoi;
|
||||
use detector::{facedet, facedet::FaceDetectionConfig, faceembed};
|
||||
use detector::{database::FaceDatabase, facedet, facedet::FaceDetectionConfig, faceembed};
|
||||
use errors::*;
|
||||
use fast_image_resize::ResizeOptions;
|
||||
|
||||
use ndarray::*;
|
||||
use ndarray_image::*;
|
||||
use ndarray_resize::NdFir;
|
||||
@@ -77,6 +78,15 @@ pub fn main() -> Result<()> {
|
||||
cli::SubCommand::List(list) => {
|
||||
println!("List: {:?}", list);
|
||||
}
|
||||
cli::SubCommand::Query(query) => {
|
||||
run_query(query)?;
|
||||
}
|
||||
cli::SubCommand::Similar(similar) => {
|
||||
run_similar(similar)?;
|
||||
}
|
||||
cli::SubCommand::Stats(stats) => {
|
||||
run_stats(stats)?;
|
||||
}
|
||||
cli::SubCommand::Completions { shell } => {
|
||||
cli::Cli::completions(shell);
|
||||
}
|
||||
@@ -89,10 +99,22 @@ where
|
||||
D: facedet::FaceDetector,
|
||||
E: faceembed::FaceEmbedder,
|
||||
{
|
||||
// Initialize database if requested
|
||||
let db = if detect.save_to_db {
|
||||
let db_path = detect
|
||||
.database
|
||||
.as_ref()
|
||||
.map(|p| p.as_path())
|
||||
.unwrap_or_else(|| std::path::Path::new("face_detections.db"));
|
||||
Some(FaceDatabase::new(db_path).change_context(Error)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let image = image::open(&detect.image)
|
||||
.change_context(Error)
|
||||
.attach_printable(detect.image.to_string_lossy().to_string())?;
|
||||
let image = image.into_rgb8();
|
||||
let (image_width, image_height) = image.dimensions();
|
||||
let mut array = image
|
||||
.into_ndarray()
|
||||
.change_context(errors::Error)
|
||||
@@ -106,6 +128,26 @@ where
|
||||
)
|
||||
.change_context(errors::Error)
|
||||
.attach_printable("Failed to detect faces")?;
|
||||
|
||||
// Store image and face detections in database if requested
|
||||
let (image_id, face_ids) = if let Some(ref database) = db {
|
||||
let image_path = detect.image.to_string_lossy();
|
||||
let img_id = database
|
||||
.store_image(&image_path, image_width, image_height)
|
||||
.change_context(Error)?;
|
||||
let face_ids = database
|
||||
.store_face_detections(img_id, &output)
|
||||
.change_context(Error)?;
|
||||
tracing::info!(
|
||||
"Stored image {} with {} faces in database",
|
||||
img_id,
|
||||
face_ids.len()
|
||||
);
|
||||
(Some(img_id), Some(face_ids))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
for bbox in &output.bbox {
|
||||
tracing::info!("Detected face: {:?}", bbox);
|
||||
use bounding_box::draw::*;
|
||||
@@ -159,6 +201,25 @@ where
|
||||
})
|
||||
.collect::<Result<Vec<Array2<f32>>>>()?;
|
||||
|
||||
// Store embeddings in database if requested
|
||||
if let (Some(database), Some(face_ids)) = (&db, &face_ids) {
|
||||
let embedding_ids = database
|
||||
.store_embeddings(face_ids, &embeddings, &detect.model_name)
|
||||
.change_context(Error)?;
|
||||
tracing::info!("Stored {} embeddings in database", embedding_ids.len());
|
||||
|
||||
// Print database statistics
|
||||
let (num_images, num_faces, num_landmarks, num_embeddings) =
|
||||
database.get_stats().change_context(Error)?;
|
||||
tracing::info!(
|
||||
"Database stats - Images: {}, Faces: {}, Landmarks: {}, Embeddings: {}",
|
||||
num_images,
|
||||
num_faces,
|
||||
num_landmarks,
|
||||
num_embeddings
|
||||
);
|
||||
}
|
||||
|
||||
let v = array.view();
|
||||
if let Some(output) = detect.output {
|
||||
let image: image::RgbImage = v
|
||||
@@ -173,3 +234,138 @@ where
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_query(query: cli::Query) -> Result<()> {
|
||||
let db = FaceDatabase::new(&query.database).change_context(Error)?;
|
||||
|
||||
if let Some(image_id) = query.image_id {
|
||||
if let Some(image) = db.get_image(image_id).change_context(Error)? {
|
||||
println!("Image: {}", image.file_path);
|
||||
println!("Dimensions: {}x{}", image.width, image.height);
|
||||
println!("Created: {}", image.created_at);
|
||||
|
||||
let faces = db.get_faces_for_image(image_id).change_context(Error)?;
|
||||
println!("Faces found: {}", faces.len());
|
||||
|
||||
for face in faces {
|
||||
println!(
|
||||
" Face ID {}: bbox({:.1}, {:.1}, {:.1}, {:.1}), confidence: {:.3}",
|
||||
face.id,
|
||||
face.bbox_x1,
|
||||
face.bbox_y1,
|
||||
face.bbox_x2,
|
||||
face.bbox_y2,
|
||||
face.confidence
|
||||
);
|
||||
|
||||
if query.show_landmarks {
|
||||
if let Some(landmarks) = db.get_landmarks(face.id).change_context(Error)? {
|
||||
println!(
|
||||
" Landmarks: left_eye({:.1}, {:.1}), right_eye({:.1}, {:.1}), nose({:.1}, {:.1})",
|
||||
landmarks.left_eye_x,
|
||||
landmarks.left_eye_y,
|
||||
landmarks.right_eye_x,
|
||||
landmarks.right_eye_y,
|
||||
landmarks.nose_x,
|
||||
landmarks.nose_y
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if query.show_embeddings {
|
||||
let embeddings = db.get_embeddings(face.id).change_context(Error)?;
|
||||
for embedding in embeddings {
|
||||
println!(
|
||||
" Embedding ({}): {} dims, model: {}",
|
||||
embedding.id,
|
||||
embedding.embedding.len(),
|
||||
embedding.model_name
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
println!("Image with ID {} not found", image_id);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(face_id) = query.face_id {
|
||||
if let Some(landmarks) = db.get_landmarks(face_id).change_context(Error)? {
|
||||
println!(
|
||||
"Landmarks for face {}: left_eye({:.1}, {:.1}), right_eye({:.1}, {:.1}), nose({:.1}, {:.1})",
|
||||
face_id,
|
||||
landmarks.left_eye_x,
|
||||
landmarks.left_eye_y,
|
||||
landmarks.right_eye_x,
|
||||
landmarks.right_eye_y,
|
||||
landmarks.nose_x,
|
||||
landmarks.nose_y
|
||||
);
|
||||
} else {
|
||||
println!("No landmarks found for face {}", face_id);
|
||||
}
|
||||
|
||||
let embeddings = db.get_embeddings(face_id).change_context(Error)?;
|
||||
println!(
|
||||
"Embeddings for face {}: {} found",
|
||||
face_id,
|
||||
embeddings.len()
|
||||
);
|
||||
for embedding in embeddings {
|
||||
println!(
|
||||
" Embedding {}: {} dims, model: {}, created: {}",
|
||||
embedding.id,
|
||||
embedding.embedding.len(),
|
||||
embedding.model_name,
|
||||
embedding.created_at
|
||||
);
|
||||
if query.show_embeddings {
|
||||
println!(
|
||||
" Values: {:?}",
|
||||
&embedding.embedding[..embedding.embedding.len().min(10)]
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_similar(similar: cli::Similar) -> Result<()> {
|
||||
let db = FaceDatabase::new(&similar.database).change_context(Error)?;
|
||||
|
||||
let embeddings = db.get_embeddings(similar.face_id).change_context(Error)?;
|
||||
if embeddings.is_empty() {
|
||||
println!("No embeddings found for face {}", similar.face_id);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let query_embedding = &embeddings[0].embedding;
|
||||
let similar_faces = db
|
||||
.find_similar_faces(query_embedding, similar.threshold, similar.limit)
|
||||
.change_context(Error)?;
|
||||
|
||||
println!(
|
||||
"Found {} similar faces (threshold: {:.3}):",
|
||||
similar_faces.len(),
|
||||
similar.threshold
|
||||
);
|
||||
for (face_id, similarity) in similar_faces {
|
||||
println!(" Face {}: similarity {:.3}", face_id, similarity);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_stats(stats: cli::Stats) -> Result<()> {
|
||||
let db = FaceDatabase::new(&stats.database).change_context(Error)?;
|
||||
let (images, faces, landmarks, embeddings) = db.get_stats().change_context(Error)?;
|
||||
|
||||
println!("Database Statistics:");
|
||||
println!(" Images: {}", images);
|
||||
println!(" Faces: {}", faces);
|
||||
println!(" Landmarks: {}", landmarks);
|
||||
println!(" Embeddings: {}", embeddings);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user