Files
face-detector/src/database.rs
uttarayan21 a340552257
Some checks failed
build / checks-matrix (push) Successful in 19m25s
build / codecov (push) Failing after 19m26s
docs / docs (push) Failing after 28m52s
build / checks-build (push) Has been cancelled
feat(cli): add clustering command with K-means support
2025-09-13 17:45:55 +05:30

735 lines
24 KiB
Rust

use crate::errors::{Error, Result};
use crate::facedet::{FaceDetectionOutput, FaceLandmarks};
use bounding_box::Aabb2;
use error_stack::ResultExt;
use ndarray_math::CosineSimilarity;
use rusqlite::{Connection, OptionalExtension, params};
use std::path::Path;
/// Owner of the SQLite connection that persists face-detection results:
/// images, detected faces, their landmarks and their embeddings.
pub struct FaceDatabase {
    /// The single connection every query in this type goes through.
    conn: Connection,
}
/// One row of the `images` table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ImageRecord {
    /// Primary key (`images.id`).
    pub id: i64,
    /// Path the image was stored under; unique across the table.
    pub file_path: String,
    /// Width as passed to `FaceDatabase::store_image` (presumably pixels).
    pub width: u32,
    /// Height as passed to `FaceDatabase::store_image` (presumably pixels).
    pub height: u32,
    /// Insertion time as text, filled in by SQLite's `CURRENT_TIMESTAMP` default.
    pub created_at: String,
}
/// One row of the `faces` table: a single detected face inside an image.
///
/// `PartialEq` only (no `Eq`) because the box and confidence are floating point.
#[derive(Debug, Clone, PartialEq)]
pub struct FaceRecord {
    /// Primary key (`faces.id`).
    pub id: i64,
    /// Id of the owning row in `images`.
    pub image_id: i64,
    /// First x corner of the bounding box (presumably the minimum x).
    pub bbox_x1: f32,
    /// First y corner of the bounding box (presumably the minimum y).
    pub bbox_y1: f32,
    /// Second x corner of the bounding box (presumably the maximum x).
    pub bbox_x2: f32,
    /// Second y corner of the bounding box (presumably the maximum y).
    pub bbox_y2: f32,
    /// Detector confidence stored alongside the box.
    pub confidence: f32,
    /// Insertion time as text, filled in by SQLite's `CURRENT_TIMESTAMP` default.
    pub created_at: String,
}
/// One row of the `landmarks` table: five facial keypoints of a stored face.
///
/// Coordinates are stored exactly as produced by the detector; their unit/space
/// (presumably image pixels) is not fixed by this module.
#[derive(Debug, Clone, PartialEq)]
pub struct LandmarkRecord {
    /// Primary key (`landmarks.id`).
    pub id: i64,
    /// Id of the owning row in `faces`.
    pub face_id: i64,
    /// Left eye, x coordinate.
    pub left_eye_x: f32,
    /// Left eye, y coordinate.
    pub left_eye_y: f32,
    /// Right eye, x coordinate.
    pub right_eye_x: f32,
    /// Right eye, y coordinate.
    pub right_eye_y: f32,
    /// Nose tip, x coordinate.
    pub nose_x: f32,
    /// Nose tip, y coordinate.
    pub nose_y: f32,
    /// Left mouth corner, x coordinate.
    pub left_mouth_x: f32,
    /// Left mouth corner, y coordinate.
    pub left_mouth_y: f32,
    /// Right mouth corner, x coordinate.
    pub right_mouth_x: f32,
    /// Right mouth corner, y coordinate.
    pub right_mouth_y: f32,
}
/// One row of the `embeddings` table, with the stored blob already decoded.
#[derive(Clone, Debug)]
pub struct EmbeddingRecord {
    /// Primary key (`embeddings.id`).
    pub id: i64,
    /// Id of the face this embedding describes.
    pub face_id: i64,
    /// The vector read back from the safetensors blob (tensor named `embedding`).
    pub embedding: ndarray::Array1<f32>,
    /// Name of the model that produced the vector, as given when it was stored.
    pub model_name: String,
    /// Insertion time as text, filled in by SQLite's `CURRENT_TIMESTAMP` default.
    pub created_at: String,
}
impl FaceDatabase {
/// Create a new database connection and initialize tables
pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Self> {
let conn = Connection::open(db_path).change_context(Error)?;
// Temporarily disable extension loading for clustering demo
// unsafe {
// let _guard = rusqlite::LoadExtensionGuard::new(&conn).change_context(Error)?;
// conn.load_extension(
// "/Users/fs0c131y/.cache/cargo/target/release/libsqlite3_safetensor_cosine.dylib",
// None::<&str>,
// )
// .change_context(Error)?;
// }
let db = Self { conn };
db.create_tables()?;
Ok(db)
}
/// Create an in-memory database for testing
pub fn in_memory() -> Result<Self> {
let conn = Connection::open_in_memory().change_context(Error)?;
let db = Self { conn };
db.create_tables()?;
Ok(db)
}
/// Create all necessary database tables
fn create_tables(&self) -> Result<()> {
// Images table
self.conn
.execute(
r#"
CREATE TABLE IF NOT EXISTS images (
id INTEGER PRIMARY KEY AUTOINCREMENT,
file_path TEXT NOT NULL UNIQUE,
width INTEGER NOT NULL,
height INTEGER NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
"#,
[],
)
.change_context(Error)?;
// Faces table
self.conn
.execute(
r#"
CREATE TABLE IF NOT EXISTS faces (
id INTEGER PRIMARY KEY AUTOINCREMENT,
image_id INTEGER NOT NULL,
bbox_x1 REAL NOT NULL,
bbox_y1 REAL NOT NULL,
bbox_x2 REAL NOT NULL,
bbox_y2 REAL NOT NULL,
confidence REAL NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE
)
"#,
[],
)
.change_context(Error)?;
// Landmarks table
self.conn
.execute(
r#"
CREATE TABLE IF NOT EXISTS landmarks (
id INTEGER PRIMARY KEY AUTOINCREMENT,
face_id INTEGER NOT NULL,
left_eye_x REAL NOT NULL,
left_eye_y REAL NOT NULL,
right_eye_x REAL NOT NULL,
right_eye_y REAL NOT NULL,
nose_x REAL NOT NULL,
nose_y REAL NOT NULL,
left_mouth_x REAL NOT NULL,
left_mouth_y REAL NOT NULL,
right_mouth_x REAL NOT NULL,
right_mouth_y REAL NOT NULL,
FOREIGN KEY (face_id) REFERENCES faces (id) ON DELETE CASCADE
)
"#,
[],
)
.change_context(Error)?;
// Embeddings table
self.conn
.execute(
r#"
CREATE TABLE IF NOT EXISTS embeddings (
id INTEGER PRIMARY KEY AUTOINCREMENT,
face_id INTEGER NOT NULL,
embedding BLOB NOT NULL,
model_name TEXT NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (face_id) REFERENCES faces (id) ON DELETE CASCADE
)
"#,
[],
)
.change_context(Error)?;
// Create indexes for better performance
self.conn
.execute(
"CREATE INDEX IF NOT EXISTS idx_faces_image_id ON faces (image_id)",
[],
)
.change_context(Error)?;
self.conn
.execute(
"CREATE INDEX IF NOT EXISTS idx_landmarks_face_id ON landmarks (face_id)",
[],
)
.change_context(Error)?;
self.conn
.execute(
"CREATE INDEX IF NOT EXISTS idx_embeddings_face_id ON embeddings (face_id)",
[],
)
.change_context(Error)?;
Ok(())
}
/// Store image metadata and return the image ID
pub fn store_image(&self, file_path: &str, width: u32, height: u32) -> Result<i64> {
let mut stmt = self
.conn
.prepare("INSERT OR REPLACE INTO images (file_path, width, height) VALUES (?1, ?2, ?3)")
.change_context(Error)?;
Ok(stmt
.insert(params![file_path, width, height])
.change_context(Error)?)
}
/// Store face detection results
pub fn store_face_detections(
&self,
image_id: i64,
detection_output: &FaceDetectionOutput,
) -> Result<Vec<i64>> {
let mut face_ids = Vec::new();
for (i, bbox) in detection_output.bbox.iter().enumerate() {
let confidence = detection_output.confidence.get(i).copied().unwrap_or(0.0);
let face_id = self.store_face(image_id, bbox, confidence)?;
face_ids.push(face_id);
// Store landmarks if available
if let Some(landmarks) = detection_output.landmark.get(i) {
self.store_landmarks(face_id, landmarks)?;
}
}
Ok(face_ids)
}
/// Store a single face detection
pub fn store_face(&self, image_id: i64, bbox: &Aabb2<usize>, confidence: f32) -> Result<i64> {
let mut stmt = self
.conn
.prepare(
r#"
INSERT INTO faces (image_id, bbox_x1, bbox_y1, bbox_x2, bbox_y2, confidence)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)
"#,
)
.change_context(Error)?;
Ok(stmt
.insert(params![
image_id,
bbox.x1() as f32,
bbox.y1() as f32,
bbox.x2() as f32,
bbox.y2() as f32,
confidence
])
.change_context(Error)?)
}
/// Store face landmarks
pub fn store_landmarks(&self, face_id: i64, landmarks: &FaceLandmarks) -> Result<i64> {
let mut stmt = self
.conn
.prepare(
r#"
INSERT INTO landmarks
(face_id, left_eye_x, left_eye_y, right_eye_x, right_eye_y,
nose_x, nose_y, left_mouth_x, left_mouth_y, right_mouth_x, right_mouth_y)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)
"#,
)
.change_context(Error)?;
Ok(stmt
.insert(params![
face_id,
landmarks.left_eye.x,
landmarks.left_eye.y,
landmarks.right_eye.x,
landmarks.right_eye.y,
landmarks.nose.x,
landmarks.nose.y,
landmarks.left_mouth.x,
landmarks.left_mouth.y,
landmarks.right_mouth.x,
landmarks.right_mouth.y,
])
.change_context(Error)?)
}
/// Store face embeddings
pub fn store_embeddings(
&self,
face_ids: &[i64],
embeddings: &[ndarray::Array2<f32>],
model_name: &str,
) -> Result<Vec<i64>> {
let mut embedding_ids = Vec::new();
for (face_idx, embedding_batch) in embeddings.iter().enumerate() {
for (batch_idx, embedding_row) in embedding_batch.rows().into_iter().enumerate() {
let global_idx = face_idx * embedding_batch.nrows() + batch_idx;
if global_idx >= face_ids.len() {
break;
}
let face_id = face_ids[global_idx];
let embedding_id =
self.store_single_embedding(face_id, embedding_row, model_name)?;
embedding_ids.push(embedding_id);
}
}
Ok(embedding_ids)
}
/// Store a single embedding
pub fn store_single_embedding(
&self,
face_id: i64,
embedding: ndarray::ArrayView1<f32>,
model_name: &str,
) -> Result<i64> {
let safe_arrays =
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding)])
.change_context(Error)?;
let embedding_bytes = safe_arrays.serialize().change_context(Error)?;
let mut stmt = self
.conn
.prepare("INSERT INTO embeddings (face_id, embedding, model_name) VALUES (?1, ?2, ?3)")
.change_context(Error)?;
stmt.execute(params![face_id, embedding_bytes, model_name])
.change_context(Error)?;
Ok(self.conn.last_insert_rowid())
}
/// Get image by ID
pub fn get_image(&self, image_id: i64) -> Result<Option<ImageRecord>> {
let mut stmt = self
.conn
.prepare("SELECT id, file_path, width, height, created_at FROM images WHERE id = ?1")
.change_context(Error)?;
let result = stmt
.query_row(params![image_id], |row| {
Ok(ImageRecord {
id: row.get(0)?,
file_path: row.get(1)?,
width: row.get(2)?,
height: row.get(3)?,
created_at: row.get(4)?,
})
})
.optional()
.change_context(Error)?;
Ok(result)
}
/// Get all faces for an image
pub fn get_faces_for_image(&self, image_id: i64) -> Result<Vec<FaceRecord>> {
let mut stmt = self
.conn
.prepare(
r#"
SELECT id, image_id, bbox_x1, bbox_y1, bbox_x2, bbox_y2, confidence, created_at
FROM faces WHERE image_id = ?1
"#,
)
.change_context(Error)?;
let face_iter = stmt
.query_map(params![image_id], |row| {
Ok(FaceRecord {
id: row.get(0)?,
image_id: row.get(1)?,
bbox_x1: row.get(2)?,
bbox_y1: row.get(3)?,
bbox_x2: row.get(4)?,
bbox_y2: row.get(5)?,
confidence: row.get(6)?,
created_at: row.get(7)?,
})
})
.change_context(Error)?;
let mut faces = Vec::new();
for face in face_iter {
faces.push(face.change_context(Error)?);
}
Ok(faces)
}
/// Get landmarks for a face
pub fn get_landmarks(&self, face_id: i64) -> Result<Option<LandmarkRecord>> {
let mut stmt = self
.conn
.prepare(
r#"
SELECT id, face_id, left_eye_x, left_eye_y, right_eye_x, right_eye_y,
nose_x, nose_y, left_mouth_x, left_mouth_y, right_mouth_x, right_mouth_y
FROM landmarks WHERE face_id = ?1
"#,
)
.change_context(Error)?;
let result = stmt
.query_row(params![face_id], |row| {
Ok(LandmarkRecord {
id: row.get(0)?,
face_id: row.get(1)?,
left_eye_x: row.get(2)?,
left_eye_y: row.get(3)?,
right_eye_x: row.get(4)?,
right_eye_y: row.get(5)?,
nose_x: row.get(6)?,
nose_y: row.get(7)?,
left_mouth_x: row.get(8)?,
left_mouth_y: row.get(9)?,
right_mouth_x: row.get(10)?,
right_mouth_y: row.get(11)?,
})
})
.optional()
.change_context(Error)?;
Ok(result)
}
/// Get embeddings for a face
pub fn get_embeddings(&self, face_id: i64) -> Result<Vec<EmbeddingRecord>> {
let mut stmt = self
.conn
.prepare(
"SELECT id, face_id, embedding, model_name, created_at FROM embeddings WHERE face_id = ?1",
)
.change_context(Error)?;
let embedding_iter = stmt
.query_map(params![face_id], |row| {
let embedding_bytes: Vec<u8> = row.get(2)?;
let embedding: ndarray::Array1<f32> = {
let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes)
.change_context(Error)
// .change_context(Error)?
.unwrap();
sf.tensor::<f32, ndarray::Ix1>("embedding")
// .change_context(Error)?
.unwrap()
.to_owned()
};
Ok(EmbeddingRecord {
id: row.get(0)?,
face_id: row.get(1)?,
embedding,
model_name: row.get(3)?,
created_at: row.get(4)?,
})
})
.change_context(Error)?;
let mut embeddings = Vec::new();
for embedding in embedding_iter {
embeddings.push(embedding.change_context(Error)?);
}
Ok(embeddings)
}
pub fn get_image_for_face(&self, face_id: i64) -> Result<Option<ImageRecord>> {
let mut stmt = self
.conn
.prepare(
r#"
SELECT images.id, images.file_path, images.width, images.height, images.created_at
FROM images
JOIN faces ON faces.image_id = images.id
WHERE faces.id = ?1
"#,
)
.change_context(Error)?;
let result = stmt
.query_row(params![face_id], |row| {
Ok(ImageRecord {
id: row.get(0)?,
file_path: row.get(1)?,
width: row.get(2)?,
height: row.get(3)?,
created_at: row.get(4)?,
})
})
.optional()
.change_context(Error)?;
Ok(result)
}
/// Get database statistics
pub fn get_stats(&self) -> Result<(usize, usize, usize, usize)> {
let images: usize = self
.conn
.query_row("SELECT COUNT(*) FROM images", [], |row| row.get(0))
.change_context(Error)?;
let faces: usize = self
.conn
.query_row("SELECT COUNT(*) FROM faces", [], |row| row.get(0))
.change_context(Error)?;
let landmarks: usize = self
.conn
.query_row("SELECT COUNT(*) FROM landmarks", [], |row| row.get(0))
.change_context(Error)?;
let embeddings: usize = self
.conn
.query_row("SELECT COUNT(*) FROM embeddings", [], |row| row.get(0))
.change_context(Error)?;
Ok((images, faces, landmarks, embeddings))
}
/// Find similar faces based on cosine similarity of embeddings
/// Return ids and similarity scores of similar faces
pub fn find_similar_faces(
&self,
embedding: &ndarray::Array1<f32>,
threshold: f32,
limit: usize,
) -> Result<Vec<(i64, f32)>> {
// Serialize the query embedding to bytes
let embedding_bytes =
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding.view())])
.change_context(Error)?
.serialize()
.change_context(Error)?;
let mut stmt = self
.conn
.prepare(
r#" SELECT face_id, cosine_similarity(?1, embedding) as similarity
FROM embeddings
WHERE cosine_similarity(?1, embedding) >= ?2
ORDER BY similarity DESC
LIMIT ?3"#,
)
.change_context(Error)?;
let result = stmt
.query_map(params![embedding_bytes, threshold, limit], |row| {
Ok((row.get::<_, i64>(0)?, row.get::<_, f32>(1)?))
})
.change_context(Error)?
.map(|r| r.change_context(Error))
.collect::<Result<Vec<_>>>()?;
// let mut results = Vec::new();
// for result in result_iter {
// results.push(result.change_context(Error)?);
// }
Ok(result)
}
pub fn query_similarity(&self, embedding: &ndarray::Array1<f32>) {
let embedding_bytes =
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding.view())])
.change_context(Error)
.unwrap()
.serialize()
.change_context(Error)
.unwrap();
let mut stmt = self
.conn
.prepare(
r#"
SELECT face_id,
cosine_similarity(?1, embedding)
FROM embeddings
"#,
)
.change_context(Error)
.unwrap();
let result_iter = stmt
.query_map(params![embedding_bytes], |row| {
Ok((row.get::<_, i64>(0)?, row.get::<_, f32>(1)?))
})
.change_context(Error)
.unwrap();
for result in result_iter {
println!("{:?}", result);
}
}
/// Get all embeddings for a specific model
pub fn get_all_embeddings(&self, model_name: Option<&str>) -> Result<Vec<EmbeddingRecord>> {
let mut embeddings = Vec::new();
if let Some(model) = model_name {
let mut stmt = self.conn.prepare(
"SELECT id, face_id, embedding, model_name, created_at FROM embeddings WHERE model_name = ?1"
).change_context(Error)?;
let embedding_iter = stmt
.query_map(params![model], |row| {
let embedding_bytes: Vec<u8> = row.get(2)?;
let embedding: ndarray::Array1<f32> = {
let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes)
.change_context(Error)
.unwrap();
sf.tensor::<f32, ndarray::Ix1>("embedding")
.unwrap()
.to_owned()
};
Ok(EmbeddingRecord {
id: row.get(0)?,
face_id: row.get(1)?,
embedding,
model_name: row.get(3)?,
created_at: row.get(4)?,
})
})
.change_context(Error)?;
for embedding in embedding_iter {
embeddings.push(embedding.change_context(Error)?);
}
} else {
let mut stmt = self
.conn
.prepare("SELECT id, face_id, embedding, model_name, created_at FROM embeddings")
.change_context(Error)?;
let embedding_iter = stmt
.query_map([], |row| {
let embedding_bytes: Vec<u8> = row.get(2)?;
let embedding: ndarray::Array1<f32> = {
let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes)
.change_context(Error)
.unwrap();
sf.tensor::<f32, ndarray::Ix1>("embedding")
.unwrap()
.to_owned()
};
Ok(EmbeddingRecord {
id: row.get(0)?,
face_id: row.get(1)?,
embedding,
model_name: row.get(3)?,
created_at: row.get(4)?,
})
})
.change_context(Error)?;
for embedding in embedding_iter {
embeddings.push(embedding.change_context(Error)?);
}
}
Ok(embeddings)
}
}
fn add_sqlite_cosine_similarity(db: &Connection) -> Result<()> {
use rusqlite::functions::*;
db.create_scalar_function(
"cosine_similarity",
2,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
move |ctx| {
if ctx.len() != 2 {
return Err(rusqlite::Error::UserFunctionError(
"cosine_similarity requires exactly 2 arguments".into(),
));
}
let array_1 = ctx.get_raw(0).as_blob()?;
let array_2 = ctx.get_raw(1).as_blob()?;
let array_1_st = ndarray_safetensors::SafeArraysView::from_bytes(array_1)
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
let array_2_st = ndarray_safetensors::SafeArraysView::from_bytes(array_2)
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
let array_view_1 = array_1_st
.tensor_by_index::<f32, ndarray::Ix1>(0)
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
let array_view_2 = array_2_st
.tensor_by_index::<f32, ndarray::Ix1>(0)
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
let similarity = array_view_1
.cosine_similarity(array_view_2)
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
Ok(similarity)
},
)
.change_context(Error)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created in-memory database contains no rows at all.
    #[test]
    fn test_database_creation() -> Result<()> {
        let db = FaceDatabase::in_memory()?;
        assert_eq!(db.get_stats()?, (0, 0, 0, 0));
        Ok(())
    }

    /// Image metadata survives a store/load round trip.
    #[test]
    fn test_store_and_retrieve_image() -> Result<()> {
        let db = FaceDatabase::in_memory()?;
        let path = "/path/to/image.jpg";
        let id = db.store_image(path, 800, 600)?;
        let stored = db
            .get_image(id)?
            .expect("an image that was just stored must be retrievable");
        assert_eq!(
            (stored.file_path.as_str(), stored.width, stored.height),
            (path, 800, 600)
        );
        Ok(())
    }
}