use crate::errors::{Error, Result}; use crate::facedet::{FaceDetectionOutput, FaceLandmarks}; use bounding_box::Aabb2; use error_stack::ResultExt; use ndarray_math::CosineSimilarity; use rusqlite::{Connection, OptionalExtension, params}; use std::path::Path; /// Database connection and operations for face detection results pub struct FaceDatabase { conn: Connection, } /// Represents a stored image record #[derive(Debug, Clone)] pub struct ImageRecord { pub id: i64, pub file_path: String, pub width: u32, pub height: u32, pub created_at: String, } /// Represents a stored face detection record #[derive(Debug, Clone)] pub struct FaceRecord { pub id: i64, pub image_id: i64, pub bbox_x1: f32, pub bbox_y1: f32, pub bbox_x2: f32, pub bbox_y2: f32, pub confidence: f32, pub created_at: String, } /// Represents stored face landmarks #[derive(Debug, Clone)] pub struct LandmarkRecord { pub id: i64, pub face_id: i64, pub left_eye_x: f32, pub left_eye_y: f32, pub right_eye_x: f32, pub right_eye_y: f32, pub nose_x: f32, pub nose_y: f32, pub left_mouth_x: f32, pub left_mouth_y: f32, pub right_mouth_x: f32, pub right_mouth_y: f32, } /// Represents a stored face embedding #[derive(Debug, Clone)] pub struct EmbeddingRecord { pub id: i64, pub face_id: i64, pub embedding: ndarray::Array1, pub model_name: String, pub created_at: String, } impl FaceDatabase { /// Create a new database connection and initialize tables pub fn new>(db_path: P) -> Result { let conn = Connection::open(db_path).change_context(Error)?; // Temporarily disable extension loading for clustering demo // unsafe { // let _guard = rusqlite::LoadExtensionGuard::new(&conn).change_context(Error)?; // conn.load_extension( // "/Users/fs0c131y/.cache/cargo/target/release/libsqlite3_safetensor_cosine.dylib", // None::<&str>, // ) // .change_context(Error)?; // } let db = Self { conn }; db.create_tables()?; Ok(db) } /// Create an in-memory database for testing pub fn in_memory() -> Result { let conn = Connection::open_in_memory().change_context(Error)?; let db = Self { conn }; db.create_tables()?; Ok(db) } /// Create all necessary database tables fn create_tables(&self) -> Result<()> { // Images table self.conn .execute( r#" CREATE TABLE IF NOT EXISTS images ( id INTEGER PRIMARY KEY AUTOINCREMENT, file_path TEXT NOT NULL UNIQUE, width INTEGER NOT NULL, height INTEGER NOT NULL, created_at DATETIME DEFAULT CURRENT_TIMESTAMP ) "#, [], ) .change_context(Error)?; // Faces table self.conn .execute( r#" CREATE TABLE IF NOT EXISTS faces ( id INTEGER PRIMARY KEY AUTOINCREMENT, image_id INTEGER NOT NULL, bbox_x1 REAL NOT NULL, bbox_y1 REAL NOT NULL, bbox_x2 REAL NOT NULL, bbox_y2 REAL NOT NULL, confidence REAL NOT NULL, created_at DATETIME DEFAULT CURRENT_TIMESTAMP, FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE ) "#, [], ) .change_context(Error)?; // Landmarks table self.conn .execute( r#" CREATE TABLE IF NOT EXISTS landmarks ( id INTEGER PRIMARY KEY AUTOINCREMENT, face_id INTEGER NOT NULL, left_eye_x REAL NOT NULL, left_eye_y REAL NOT NULL, right_eye_x REAL NOT NULL, right_eye_y REAL NOT NULL, nose_x REAL NOT NULL, nose_y REAL NOT NULL, left_mouth_x REAL NOT NULL, left_mouth_y REAL NOT NULL, right_mouth_x REAL NOT NULL, right_mouth_y REAL NOT NULL, FOREIGN KEY (face_id) REFERENCES faces (id) ON DELETE CASCADE ) "#, [], ) .change_context(Error)?; // Embeddings table self.conn .execute( r#" CREATE TABLE IF NOT EXISTS embeddings ( id INTEGER PRIMARY KEY AUTOINCREMENT, face_id INTEGER NOT NULL, embedding BLOB NOT NULL, model_name TEXT NOT NULL, created_at DATETIME DEFAULT CURRENT_TIMESTAMP, FOREIGN KEY (face_id) REFERENCES faces (id) ON DELETE CASCADE ) "#, [], ) .change_context(Error)?; // Create indexes for better performance self.conn .execute( "CREATE INDEX IF NOT EXISTS idx_faces_image_id ON faces (image_id)", [], ) .change_context(Error)?; self.conn .execute( "CREATE INDEX IF NOT EXISTS idx_landmarks_face_id ON landmarks (face_id)", [], ) .change_context(Error)?; self.conn .execute( "CREATE INDEX IF NOT EXISTS idx_embeddings_face_id ON embeddings (face_id)", [], ) .change_context(Error)?; Ok(()) } /// Store image metadata and return the image ID pub fn store_image(&self, file_path: &str, width: u32, height: u32) -> Result { let mut stmt = self .conn .prepare("INSERT OR REPLACE INTO images (file_path, width, height) VALUES (?1, ?2, ?3)") .change_context(Error)?; Ok(stmt .insert(params![file_path, width, height]) .change_context(Error)?) } /// Store face detection results pub fn store_face_detections( &self, image_id: i64, detection_output: &FaceDetectionOutput, ) -> Result> { let mut face_ids = Vec::new(); for (i, bbox) in detection_output.bbox.iter().enumerate() { let confidence = detection_output.confidence.get(i).copied().unwrap_or(0.0); let face_id = self.store_face(image_id, bbox, confidence)?; face_ids.push(face_id); // Store landmarks if available if let Some(landmarks) = detection_output.landmark.get(i) { self.store_landmarks(face_id, landmarks)?; } } Ok(face_ids) } /// Store a single face detection pub fn store_face(&self, image_id: i64, bbox: &Aabb2, confidence: f32) -> Result { let mut stmt = self .conn .prepare( r#" INSERT INTO faces (image_id, bbox_x1, bbox_y1, bbox_x2, bbox_y2, confidence) VALUES (?1, ?2, ?3, ?4, ?5, ?6) "#, ) .change_context(Error)?; Ok(stmt .insert(params![ image_id, bbox.x1() as f32, bbox.y1() as f32, bbox.x2() as f32, bbox.y2() as f32, confidence ]) .change_context(Error)?) } /// Store face landmarks pub fn store_landmarks(&self, face_id: i64, landmarks: &FaceLandmarks) -> Result { let mut stmt = self .conn .prepare( r#" INSERT INTO landmarks (face_id, left_eye_x, left_eye_y, right_eye_x, right_eye_y, nose_x, nose_y, left_mouth_x, left_mouth_y, right_mouth_x, right_mouth_y) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11) "#, ) .change_context(Error)?; Ok(stmt .insert(params![ face_id, landmarks.left_eye.x, landmarks.left_eye.y, landmarks.right_eye.x, landmarks.right_eye.y, landmarks.nose.x, landmarks.nose.y, landmarks.left_mouth.x, landmarks.left_mouth.y, landmarks.right_mouth.x, landmarks.right_mouth.y, ]) .change_context(Error)?) } /// Store face embeddings pub fn store_embeddings( &self, face_ids: &[i64], embeddings: &[ndarray::Array2], model_name: &str, ) -> Result> { let mut embedding_ids = Vec::new(); for (face_idx, embedding_batch) in embeddings.iter().enumerate() { for (batch_idx, embedding_row) in embedding_batch.rows().into_iter().enumerate() { let global_idx = face_idx * embedding_batch.nrows() + batch_idx; if global_idx >= face_ids.len() { break; } let face_id = face_ids[global_idx]; let embedding_id = self.store_single_embedding(face_id, embedding_row, model_name)?; embedding_ids.push(embedding_id); } } Ok(embedding_ids) } /// Store a single embedding pub fn store_single_embedding( &self, face_id: i64, embedding: ndarray::ArrayView1, model_name: &str, ) -> Result { let safe_arrays = ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding)]) .change_context(Error)?; let embedding_bytes = safe_arrays.serialize().change_context(Error)?; let mut stmt = self .conn .prepare("INSERT INTO embeddings (face_id, embedding, model_name) VALUES (?1, ?2, ?3)") .change_context(Error)?; stmt.execute(params![face_id, embedding_bytes, model_name]) .change_context(Error)?; Ok(self.conn.last_insert_rowid()) } /// Get image by ID pub fn get_image(&self, image_id: i64) -> Result> { let mut stmt = self .conn .prepare("SELECT id, file_path, width, height, created_at FROM images WHERE id = ?1") .change_context(Error)?; let result = stmt .query_row(params![image_id], |row| { Ok(ImageRecord { id: row.get(0)?, file_path: row.get(1)?, width: row.get(2)?, height: row.get(3)?, created_at: row.get(4)?, }) }) .optional() .change_context(Error)?; Ok(result) } /// Get all faces for an image pub fn get_faces_for_image(&self, image_id: i64) -> Result> { let mut stmt = self .conn .prepare( r#" SELECT id, image_id, bbox_x1, bbox_y1, bbox_x2, bbox_y2, confidence, created_at FROM faces WHERE image_id = ?1 "#, ) .change_context(Error)?; let face_iter = stmt .query_map(params![image_id], |row| { Ok(FaceRecord { id: row.get(0)?, image_id: row.get(1)?, bbox_x1: row.get(2)?, bbox_y1: row.get(3)?, bbox_x2: row.get(4)?, bbox_y2: row.get(5)?, confidence: row.get(6)?, created_at: row.get(7)?, }) }) .change_context(Error)?; let mut faces = Vec::new(); for face in face_iter { faces.push(face.change_context(Error)?); } Ok(faces) } /// Get landmarks for a face pub fn get_landmarks(&self, face_id: i64) -> Result> { let mut stmt = self .conn .prepare( r#" SELECT id, face_id, left_eye_x, left_eye_y, right_eye_x, right_eye_y, nose_x, nose_y, left_mouth_x, left_mouth_y, right_mouth_x, right_mouth_y FROM landmarks WHERE face_id = ?1 "#, ) .change_context(Error)?; let result = stmt .query_row(params![face_id], |row| { Ok(LandmarkRecord { id: row.get(0)?, face_id: row.get(1)?, left_eye_x: row.get(2)?, left_eye_y: row.get(3)?, right_eye_x: row.get(4)?, right_eye_y: row.get(5)?, nose_x: row.get(6)?, nose_y: row.get(7)?, left_mouth_x: row.get(8)?, left_mouth_y: row.get(9)?, right_mouth_x: row.get(10)?, right_mouth_y: row.get(11)?, }) }) .optional() .change_context(Error)?; Ok(result) } /// Get embeddings for a face pub fn get_embeddings(&self, face_id: i64) -> Result> { let mut stmt = self .conn .prepare( "SELECT id, face_id, embedding, model_name, created_at FROM embeddings WHERE face_id = ?1", ) .change_context(Error)?; let embedding_iter = stmt .query_map(params![face_id], |row| { let embedding_bytes: Vec = row.get(2)?; let embedding: ndarray::Array1 = { let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes) .change_context(Error) // .change_context(Error)? .unwrap(); sf.tensor::("embedding") // .change_context(Error)? .unwrap() .to_owned() }; Ok(EmbeddingRecord { id: row.get(0)?, face_id: row.get(1)?, embedding, model_name: row.get(3)?, created_at: row.get(4)?, }) }) .change_context(Error)?; let mut embeddings = Vec::new(); for embedding in embedding_iter { embeddings.push(embedding.change_context(Error)?); } Ok(embeddings) } pub fn get_image_for_face(&self, face_id: i64) -> Result> { let mut stmt = self .conn .prepare( r#" SELECT images.id, images.file_path, images.width, images.height, images.created_at FROM images JOIN faces ON faces.image_id = images.id WHERE faces.id = ?1 "#, ) .change_context(Error)?; let result = stmt .query_row(params![face_id], |row| { Ok(ImageRecord { id: row.get(0)?, file_path: row.get(1)?, width: row.get(2)?, height: row.get(3)?, created_at: row.get(4)?, }) }) .optional() .change_context(Error)?; Ok(result) } /// Get database statistics pub fn get_stats(&self) -> Result<(usize, usize, usize, usize)> { let images: usize = self .conn .query_row("SELECT COUNT(*) FROM images", [], |row| row.get(0)) .change_context(Error)?; let faces: usize = self .conn .query_row("SELECT COUNT(*) FROM faces", [], |row| row.get(0)) .change_context(Error)?; let landmarks: usize = self .conn .query_row("SELECT COUNT(*) FROM landmarks", [], |row| row.get(0)) .change_context(Error)?; let embeddings: usize = self .conn .query_row("SELECT COUNT(*) FROM embeddings", [], |row| row.get(0)) .change_context(Error)?; Ok((images, faces, landmarks, embeddings)) } /// Find similar faces based on cosine similarity of embeddings /// Return ids and similarity scores of similar faces pub fn find_similar_faces( &self, embedding: &ndarray::Array1, threshold: f32, limit: usize, ) -> Result> { // Serialize the query embedding to bytes let embedding_bytes = ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding.view())]) .change_context(Error)? .serialize() .change_context(Error)?; let mut stmt = self .conn .prepare( r#" SELECT face_id, cosine_similarity(?1, embedding) as similarity FROM embeddings WHERE cosine_similarity(?1, embedding) >= ?2 ORDER BY similarity DESC LIMIT ?3"#, ) .change_context(Error)?; let result = stmt .query_map(params![embedding_bytes, threshold, limit], |row| { Ok((row.get::<_, i64>(0)?, row.get::<_, f32>(1)?)) }) .change_context(Error)? .map(|r| r.change_context(Error)) .collect::>>()?; // let mut results = Vec::new(); // for result in result_iter { // results.push(result.change_context(Error)?); // } Ok(result) } pub fn query_similarity(&self, embedding: &ndarray::Array1) { let embedding_bytes = ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding.view())]) .change_context(Error) .unwrap() .serialize() .change_context(Error) .unwrap(); let mut stmt = self .conn .prepare( r#" SELECT face_id, cosine_similarity(?1, embedding) FROM embeddings "#, ) .change_context(Error) .unwrap(); let result_iter = stmt .query_map(params![embedding_bytes], |row| { Ok((row.get::<_, i64>(0)?, row.get::<_, f32>(1)?)) }) .change_context(Error) .unwrap(); for result in result_iter { println!("{:?}", result); } } /// Get all embeddings for a specific model pub fn get_all_embeddings(&self, model_name: Option<&str>) -> Result> { let mut embeddings = Vec::new(); if let Some(model) = model_name { let mut stmt = self.conn.prepare( "SELECT id, face_id, embedding, model_name, created_at FROM embeddings WHERE model_name = ?1" ).change_context(Error)?; let embedding_iter = stmt .query_map(params![model], |row| { let embedding_bytes: Vec = row.get(2)?; let embedding: ndarray::Array1 = { let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes) .change_context(Error) .unwrap(); sf.tensor::("embedding") .unwrap() .to_owned() }; Ok(EmbeddingRecord { id: row.get(0)?, face_id: row.get(1)?, embedding, model_name: row.get(3)?, created_at: row.get(4)?, }) }) .change_context(Error)?; for embedding in embedding_iter { embeddings.push(embedding.change_context(Error)?); } } else { let mut stmt = self .conn .prepare("SELECT id, face_id, embedding, model_name, created_at FROM embeddings") .change_context(Error)?; let embedding_iter = stmt .query_map([], |row| { let embedding_bytes: Vec = row.get(2)?; let embedding: ndarray::Array1 = { let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes) .change_context(Error) .unwrap(); sf.tensor::("embedding") .unwrap() .to_owned() }; Ok(EmbeddingRecord { id: row.get(0)?, face_id: row.get(1)?, embedding, model_name: row.get(3)?, created_at: row.get(4)?, }) }) .change_context(Error)?; for embedding in embedding_iter { embeddings.push(embedding.change_context(Error)?); } } Ok(embeddings) } } fn add_sqlite_cosine_similarity(db: &Connection) -> Result<()> { use rusqlite::functions::*; db.create_scalar_function( "cosine_similarity", 2, FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, move |ctx| { if ctx.len() != 2 { return Err(rusqlite::Error::UserFunctionError( "cosine_similarity requires exactly 2 arguments".into(), )); } let array_1 = ctx.get_raw(0).as_blob()?; let array_2 = ctx.get_raw(1).as_blob()?; let array_1_st = ndarray_safetensors::SafeArraysView::from_bytes(array_1) .map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?; let array_2_st = ndarray_safetensors::SafeArraysView::from_bytes(array_2) .map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?; let array_view_1 = array_1_st .tensor_by_index::(0) .map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?; let array_view_2 = array_2_st .tensor_by_index::(0) .map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?; let similarity = array_view_1 .cosine_similarity(array_view_2) .map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?; Ok(similarity) }, ) .change_context(Error) } #[cfg(test)] mod tests { use super::*; #[test] fn test_database_creation() -> Result<()> { let db = FaceDatabase::in_memory()?; let (images, faces, landmarks, embeddings) = db.get_stats()?; assert_eq!(images, 0); assert_eq!(faces, 0); assert_eq!(landmarks, 0); assert_eq!(embeddings, 0); Ok(()) } #[test] fn test_store_and_retrieve_image() -> Result<()> { let db = FaceDatabase::in_memory()?; let image_id = db.store_image("/path/to/image.jpg", 800, 600)?; let image = db.get_image(image_id)?.unwrap(); assert_eq!(image.file_path, "/path/to/image.jpg"); assert_eq!(image.width, 800); assert_eq!(image.height, 600); Ok(()) } }