735 lines
24 KiB
Rust
735 lines
24 KiB
Rust
use crate::errors::{Error, Result};
|
|
use crate::facedet::{FaceDetectionOutput, FaceLandmarks};
|
|
use bounding_box::Aabb2;
|
|
use error_stack::ResultExt;
|
|
use ndarray_math::CosineSimilarity;
|
|
use rusqlite::{Connection, OptionalExtension, params};
|
|
use std::path::Path;
|
|
|
|
/// Database connection and operations for face detection results
|
|
pub struct FaceDatabase {
|
|
conn: Connection,
|
|
}
|
|
|
|
/// One row of the `images` table.
#[derive(Debug, Clone)]
pub struct ImageRecord {
    /// Primary key of the image row.
    pub id: i64,
    /// Path of the image file; unique within the table.
    pub file_path: String,
    /// Image width as stored alongside the path.
    pub width: u32,
    /// Image height as stored alongside the path.
    pub height: u32,
    /// Creation timestamp as returned by SQLite (defaults to `CURRENT_TIMESTAMP`).
    pub created_at: String,
}
|
|
|
|
/// One row of the `faces` table: a detected face and its bounding box.
#[derive(Debug, Clone)]
pub struct FaceRecord {
    /// Primary key of the face row.
    pub id: i64,
    /// Id of the `images` row this face belongs to.
    pub image_id: i64,
    /// First x coordinate of the bounding box.
    pub bbox_x1: f32,
    /// First y coordinate of the bounding box.
    pub bbox_y1: f32,
    /// Second x coordinate of the bounding box.
    pub bbox_x2: f32,
    /// Second y coordinate of the bounding box.
    pub bbox_y2: f32,
    /// Detection confidence stored with the face.
    pub confidence: f32,
    /// Creation timestamp as returned by SQLite (defaults to `CURRENT_TIMESTAMP`).
    pub created_at: String,
}
|
|
|
|
/// One row of the `landmarks` table: five facial key points for a face.
#[derive(Debug, Clone)]
pub struct LandmarkRecord {
    /// Primary key of the landmark row.
    pub id: i64,
    /// Id of the `faces` row these landmarks belong to.
    pub face_id: i64,
    /// Left eye, x coordinate.
    pub left_eye_x: f32,
    /// Left eye, y coordinate.
    pub left_eye_y: f32,
    /// Right eye, x coordinate.
    pub right_eye_x: f32,
    /// Right eye, y coordinate.
    pub right_eye_y: f32,
    /// Nose, x coordinate.
    pub nose_x: f32,
    /// Nose, y coordinate.
    pub nose_y: f32,
    /// Left mouth corner, x coordinate.
    pub left_mouth_x: f32,
    /// Left mouth corner, y coordinate.
    pub left_mouth_y: f32,
    /// Right mouth corner, x coordinate.
    pub right_mouth_x: f32,
    /// Right mouth corner, y coordinate.
    pub right_mouth_y: f32,
}
|
|
|
|
/// Represents a stored face embedding
|
|
#[derive(Debug, Clone)]
|
|
pub struct EmbeddingRecord {
|
|
pub id: i64,
|
|
pub face_id: i64,
|
|
pub embedding: ndarray::Array1<f32>,
|
|
pub model_name: String,
|
|
pub created_at: String,
|
|
}
|
|
|
|
impl FaceDatabase {
|
|
/// Create a new database connection and initialize tables
|
|
pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Self> {
|
|
let conn = Connection::open(db_path).change_context(Error)?;
|
|
// Temporarily disable extension loading for clustering demo
|
|
// unsafe {
|
|
// let _guard = rusqlite::LoadExtensionGuard::new(&conn).change_context(Error)?;
|
|
// conn.load_extension(
|
|
// "/Users/fs0c131y/.cache/cargo/target/release/libsqlite3_safetensor_cosine.dylib",
|
|
// None::<&str>,
|
|
// )
|
|
// .change_context(Error)?;
|
|
// }
|
|
let db = Self { conn };
|
|
db.create_tables()?;
|
|
Ok(db)
|
|
}
|
|
|
|
/// Create an in-memory database for testing
|
|
pub fn in_memory() -> Result<Self> {
|
|
let conn = Connection::open_in_memory().change_context(Error)?;
|
|
let db = Self { conn };
|
|
db.create_tables()?;
|
|
Ok(db)
|
|
}
|
|
|
|
/// Create all necessary database tables
|
|
fn create_tables(&self) -> Result<()> {
|
|
// Images table
|
|
self.conn
|
|
.execute(
|
|
r#"
|
|
CREATE TABLE IF NOT EXISTS images (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
file_path TEXT NOT NULL UNIQUE,
|
|
width INTEGER NOT NULL,
|
|
height INTEGER NOT NULL,
|
|
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
)
|
|
"#,
|
|
[],
|
|
)
|
|
.change_context(Error)?;
|
|
|
|
// Faces table
|
|
self.conn
|
|
.execute(
|
|
r#"
|
|
CREATE TABLE IF NOT EXISTS faces (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
image_id INTEGER NOT NULL,
|
|
bbox_x1 REAL NOT NULL,
|
|
bbox_y1 REAL NOT NULL,
|
|
bbox_x2 REAL NOT NULL,
|
|
bbox_y2 REAL NOT NULL,
|
|
confidence REAL NOT NULL,
|
|
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE
|
|
)
|
|
"#,
|
|
[],
|
|
)
|
|
.change_context(Error)?;
|
|
|
|
// Landmarks table
|
|
self.conn
|
|
.execute(
|
|
r#"
|
|
CREATE TABLE IF NOT EXISTS landmarks (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
face_id INTEGER NOT NULL,
|
|
left_eye_x REAL NOT NULL,
|
|
left_eye_y REAL NOT NULL,
|
|
right_eye_x REAL NOT NULL,
|
|
right_eye_y REAL NOT NULL,
|
|
nose_x REAL NOT NULL,
|
|
nose_y REAL NOT NULL,
|
|
left_mouth_x REAL NOT NULL,
|
|
left_mouth_y REAL NOT NULL,
|
|
right_mouth_x REAL NOT NULL,
|
|
right_mouth_y REAL NOT NULL,
|
|
FOREIGN KEY (face_id) REFERENCES faces (id) ON DELETE CASCADE
|
|
)
|
|
"#,
|
|
[],
|
|
)
|
|
.change_context(Error)?;
|
|
|
|
// Embeddings table
|
|
self.conn
|
|
.execute(
|
|
r#"
|
|
CREATE TABLE IF NOT EXISTS embeddings (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
face_id INTEGER NOT NULL,
|
|
embedding BLOB NOT NULL,
|
|
model_name TEXT NOT NULL,
|
|
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
FOREIGN KEY (face_id) REFERENCES faces (id) ON DELETE CASCADE
|
|
)
|
|
"#,
|
|
[],
|
|
)
|
|
.change_context(Error)?;
|
|
|
|
// Create indexes for better performance
|
|
self.conn
|
|
.execute(
|
|
"CREATE INDEX IF NOT EXISTS idx_faces_image_id ON faces (image_id)",
|
|
[],
|
|
)
|
|
.change_context(Error)?;
|
|
|
|
self.conn
|
|
.execute(
|
|
"CREATE INDEX IF NOT EXISTS idx_landmarks_face_id ON landmarks (face_id)",
|
|
[],
|
|
)
|
|
.change_context(Error)?;
|
|
|
|
self.conn
|
|
.execute(
|
|
"CREATE INDEX IF NOT EXISTS idx_embeddings_face_id ON embeddings (face_id)",
|
|
[],
|
|
)
|
|
.change_context(Error)?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Store image metadata and return the image ID
|
|
pub fn store_image(&self, file_path: &str, width: u32, height: u32) -> Result<i64> {
|
|
let mut stmt = self
|
|
.conn
|
|
.prepare("INSERT OR REPLACE INTO images (file_path, width, height) VALUES (?1, ?2, ?3)")
|
|
.change_context(Error)?;
|
|
|
|
Ok(stmt
|
|
.insert(params![file_path, width, height])
|
|
.change_context(Error)?)
|
|
}
|
|
|
|
/// Store face detection results
|
|
pub fn store_face_detections(
|
|
&self,
|
|
image_id: i64,
|
|
detection_output: &FaceDetectionOutput,
|
|
) -> Result<Vec<i64>> {
|
|
let mut face_ids = Vec::new();
|
|
|
|
for (i, bbox) in detection_output.bbox.iter().enumerate() {
|
|
let confidence = detection_output.confidence.get(i).copied().unwrap_or(0.0);
|
|
|
|
let face_id = self.store_face(image_id, bbox, confidence)?;
|
|
face_ids.push(face_id);
|
|
|
|
// Store landmarks if available
|
|
if let Some(landmarks) = detection_output.landmark.get(i) {
|
|
self.store_landmarks(face_id, landmarks)?;
|
|
}
|
|
}
|
|
|
|
Ok(face_ids)
|
|
}
|
|
|
|
/// Store a single face detection
|
|
pub fn store_face(&self, image_id: i64, bbox: &Aabb2<usize>, confidence: f32) -> Result<i64> {
|
|
let mut stmt = self
|
|
.conn
|
|
.prepare(
|
|
r#"
|
|
INSERT INTO faces (image_id, bbox_x1, bbox_y1, bbox_x2, bbox_y2, confidence)
|
|
VALUES (?1, ?2, ?3, ?4, ?5, ?6)
|
|
"#,
|
|
)
|
|
.change_context(Error)?;
|
|
|
|
Ok(stmt
|
|
.insert(params![
|
|
image_id,
|
|
bbox.x1() as f32,
|
|
bbox.y1() as f32,
|
|
bbox.x2() as f32,
|
|
bbox.y2() as f32,
|
|
confidence
|
|
])
|
|
.change_context(Error)?)
|
|
}
|
|
|
|
/// Store face landmarks
|
|
pub fn store_landmarks(&self, face_id: i64, landmarks: &FaceLandmarks) -> Result<i64> {
|
|
let mut stmt = self
|
|
.conn
|
|
.prepare(
|
|
r#"
|
|
INSERT INTO landmarks
|
|
(face_id, left_eye_x, left_eye_y, right_eye_x, right_eye_y,
|
|
nose_x, nose_y, left_mouth_x, left_mouth_y, right_mouth_x, right_mouth_y)
|
|
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)
|
|
"#,
|
|
)
|
|
.change_context(Error)?;
|
|
|
|
Ok(stmt
|
|
.insert(params![
|
|
face_id,
|
|
landmarks.left_eye.x,
|
|
landmarks.left_eye.y,
|
|
landmarks.right_eye.x,
|
|
landmarks.right_eye.y,
|
|
landmarks.nose.x,
|
|
landmarks.nose.y,
|
|
landmarks.left_mouth.x,
|
|
landmarks.left_mouth.y,
|
|
landmarks.right_mouth.x,
|
|
landmarks.right_mouth.y,
|
|
])
|
|
.change_context(Error)?)
|
|
}
|
|
|
|
/// Store face embeddings
|
|
pub fn store_embeddings(
|
|
&self,
|
|
face_ids: &[i64],
|
|
embeddings: &[ndarray::Array2<f32>],
|
|
model_name: &str,
|
|
) -> Result<Vec<i64>> {
|
|
let mut embedding_ids = Vec::new();
|
|
|
|
for (face_idx, embedding_batch) in embeddings.iter().enumerate() {
|
|
for (batch_idx, embedding_row) in embedding_batch.rows().into_iter().enumerate() {
|
|
let global_idx = face_idx * embedding_batch.nrows() + batch_idx;
|
|
|
|
if global_idx >= face_ids.len() {
|
|
break;
|
|
}
|
|
|
|
let face_id = face_ids[global_idx];
|
|
let embedding_id =
|
|
self.store_single_embedding(face_id, embedding_row, model_name)?;
|
|
embedding_ids.push(embedding_id);
|
|
}
|
|
}
|
|
|
|
Ok(embedding_ids)
|
|
}
|
|
|
|
/// Store a single embedding
|
|
pub fn store_single_embedding(
|
|
&self,
|
|
face_id: i64,
|
|
embedding: ndarray::ArrayView1<f32>,
|
|
model_name: &str,
|
|
) -> Result<i64> {
|
|
let safe_arrays =
|
|
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding)])
|
|
.change_context(Error)?;
|
|
|
|
let embedding_bytes = safe_arrays.serialize().change_context(Error)?;
|
|
|
|
let mut stmt = self
|
|
.conn
|
|
.prepare("INSERT INTO embeddings (face_id, embedding, model_name) VALUES (?1, ?2, ?3)")
|
|
.change_context(Error)?;
|
|
|
|
stmt.execute(params![face_id, embedding_bytes, model_name])
|
|
.change_context(Error)?;
|
|
|
|
Ok(self.conn.last_insert_rowid())
|
|
}
|
|
|
|
/// Get image by ID
|
|
pub fn get_image(&self, image_id: i64) -> Result<Option<ImageRecord>> {
|
|
let mut stmt = self
|
|
.conn
|
|
.prepare("SELECT id, file_path, width, height, created_at FROM images WHERE id = ?1")
|
|
.change_context(Error)?;
|
|
|
|
let result = stmt
|
|
.query_row(params![image_id], |row| {
|
|
Ok(ImageRecord {
|
|
id: row.get(0)?,
|
|
file_path: row.get(1)?,
|
|
width: row.get(2)?,
|
|
height: row.get(3)?,
|
|
created_at: row.get(4)?,
|
|
})
|
|
})
|
|
.optional()
|
|
.change_context(Error)?;
|
|
|
|
Ok(result)
|
|
}
|
|
|
|
/// Get all faces for an image
|
|
pub fn get_faces_for_image(&self, image_id: i64) -> Result<Vec<FaceRecord>> {
|
|
let mut stmt = self
|
|
.conn
|
|
.prepare(
|
|
r#"
|
|
SELECT id, image_id, bbox_x1, bbox_y1, bbox_x2, bbox_y2, confidence, created_at
|
|
FROM faces WHERE image_id = ?1
|
|
"#,
|
|
)
|
|
.change_context(Error)?;
|
|
|
|
let face_iter = stmt
|
|
.query_map(params![image_id], |row| {
|
|
Ok(FaceRecord {
|
|
id: row.get(0)?,
|
|
image_id: row.get(1)?,
|
|
bbox_x1: row.get(2)?,
|
|
bbox_y1: row.get(3)?,
|
|
bbox_x2: row.get(4)?,
|
|
bbox_y2: row.get(5)?,
|
|
confidence: row.get(6)?,
|
|
created_at: row.get(7)?,
|
|
})
|
|
})
|
|
.change_context(Error)?;
|
|
|
|
let mut faces = Vec::new();
|
|
for face in face_iter {
|
|
faces.push(face.change_context(Error)?);
|
|
}
|
|
|
|
Ok(faces)
|
|
}
|
|
|
|
/// Get landmarks for a face
|
|
pub fn get_landmarks(&self, face_id: i64) -> Result<Option<LandmarkRecord>> {
|
|
let mut stmt = self
|
|
.conn
|
|
.prepare(
|
|
r#"
|
|
SELECT id, face_id, left_eye_x, left_eye_y, right_eye_x, right_eye_y,
|
|
nose_x, nose_y, left_mouth_x, left_mouth_y, right_mouth_x, right_mouth_y
|
|
FROM landmarks WHERE face_id = ?1
|
|
"#,
|
|
)
|
|
.change_context(Error)?;
|
|
|
|
let result = stmt
|
|
.query_row(params![face_id], |row| {
|
|
Ok(LandmarkRecord {
|
|
id: row.get(0)?,
|
|
face_id: row.get(1)?,
|
|
left_eye_x: row.get(2)?,
|
|
left_eye_y: row.get(3)?,
|
|
right_eye_x: row.get(4)?,
|
|
right_eye_y: row.get(5)?,
|
|
nose_x: row.get(6)?,
|
|
nose_y: row.get(7)?,
|
|
left_mouth_x: row.get(8)?,
|
|
left_mouth_y: row.get(9)?,
|
|
right_mouth_x: row.get(10)?,
|
|
right_mouth_y: row.get(11)?,
|
|
})
|
|
})
|
|
.optional()
|
|
.change_context(Error)?;
|
|
|
|
Ok(result)
|
|
}
|
|
|
|
/// Get embeddings for a face
|
|
pub fn get_embeddings(&self, face_id: i64) -> Result<Vec<EmbeddingRecord>> {
|
|
let mut stmt = self
|
|
.conn
|
|
.prepare(
|
|
"SELECT id, face_id, embedding, model_name, created_at FROM embeddings WHERE face_id = ?1",
|
|
)
|
|
.change_context(Error)?;
|
|
|
|
let embedding_iter = stmt
|
|
.query_map(params![face_id], |row| {
|
|
let embedding_bytes: Vec<u8> = row.get(2)?;
|
|
let embedding: ndarray::Array1<f32> = {
|
|
let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes)
|
|
.change_context(Error)
|
|
// .change_context(Error)?
|
|
.unwrap();
|
|
sf.tensor::<f32, ndarray::Ix1>("embedding")
|
|
// .change_context(Error)?
|
|
.unwrap()
|
|
.to_owned()
|
|
};
|
|
|
|
Ok(EmbeddingRecord {
|
|
id: row.get(0)?,
|
|
face_id: row.get(1)?,
|
|
embedding,
|
|
model_name: row.get(3)?,
|
|
created_at: row.get(4)?,
|
|
})
|
|
})
|
|
.change_context(Error)?;
|
|
|
|
let mut embeddings = Vec::new();
|
|
for embedding in embedding_iter {
|
|
embeddings.push(embedding.change_context(Error)?);
|
|
}
|
|
|
|
Ok(embeddings)
|
|
}
|
|
|
|
pub fn get_image_for_face(&self, face_id: i64) -> Result<Option<ImageRecord>> {
|
|
let mut stmt = self
|
|
.conn
|
|
.prepare(
|
|
r#"
|
|
SELECT images.id, images.file_path, images.width, images.height, images.created_at
|
|
FROM images
|
|
JOIN faces ON faces.image_id = images.id
|
|
WHERE faces.id = ?1
|
|
"#,
|
|
)
|
|
.change_context(Error)?;
|
|
|
|
let result = stmt
|
|
.query_row(params![face_id], |row| {
|
|
Ok(ImageRecord {
|
|
id: row.get(0)?,
|
|
file_path: row.get(1)?,
|
|
width: row.get(2)?,
|
|
height: row.get(3)?,
|
|
created_at: row.get(4)?,
|
|
})
|
|
})
|
|
.optional()
|
|
.change_context(Error)?;
|
|
|
|
Ok(result)
|
|
}
|
|
|
|
/// Get database statistics
|
|
pub fn get_stats(&self) -> Result<(usize, usize, usize, usize)> {
|
|
let images: usize = self
|
|
.conn
|
|
.query_row("SELECT COUNT(*) FROM images", [], |row| row.get(0))
|
|
.change_context(Error)?;
|
|
|
|
let faces: usize = self
|
|
.conn
|
|
.query_row("SELECT COUNT(*) FROM faces", [], |row| row.get(0))
|
|
.change_context(Error)?;
|
|
|
|
let landmarks: usize = self
|
|
.conn
|
|
.query_row("SELECT COUNT(*) FROM landmarks", [], |row| row.get(0))
|
|
.change_context(Error)?;
|
|
|
|
let embeddings: usize = self
|
|
.conn
|
|
.query_row("SELECT COUNT(*) FROM embeddings", [], |row| row.get(0))
|
|
.change_context(Error)?;
|
|
|
|
Ok((images, faces, landmarks, embeddings))
|
|
}
|
|
|
|
/// Find similar faces based on cosine similarity of embeddings
|
|
/// Return ids and similarity scores of similar faces
|
|
pub fn find_similar_faces(
|
|
&self,
|
|
embedding: &ndarray::Array1<f32>,
|
|
threshold: f32,
|
|
limit: usize,
|
|
) -> Result<Vec<(i64, f32)>> {
|
|
// Serialize the query embedding to bytes
|
|
let embedding_bytes =
|
|
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding.view())])
|
|
.change_context(Error)?
|
|
.serialize()
|
|
.change_context(Error)?;
|
|
|
|
let mut stmt = self
|
|
.conn
|
|
.prepare(
|
|
r#" SELECT face_id, cosine_similarity(?1, embedding) as similarity
|
|
FROM embeddings
|
|
WHERE cosine_similarity(?1, embedding) >= ?2
|
|
ORDER BY similarity DESC
|
|
LIMIT ?3"#,
|
|
)
|
|
.change_context(Error)?;
|
|
|
|
let result = stmt
|
|
.query_map(params![embedding_bytes, threshold, limit], |row| {
|
|
Ok((row.get::<_, i64>(0)?, row.get::<_, f32>(1)?))
|
|
})
|
|
.change_context(Error)?
|
|
.map(|r| r.change_context(Error))
|
|
.collect::<Result<Vec<_>>>()?;
|
|
|
|
// let mut results = Vec::new();
|
|
// for result in result_iter {
|
|
// results.push(result.change_context(Error)?);
|
|
// }
|
|
|
|
Ok(result)
|
|
}
|
|
|
|
pub fn query_similarity(&self, embedding: &ndarray::Array1<f32>) {
|
|
let embedding_bytes =
|
|
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding.view())])
|
|
.change_context(Error)
|
|
.unwrap()
|
|
.serialize()
|
|
.change_context(Error)
|
|
.unwrap();
|
|
|
|
let mut stmt = self
|
|
.conn
|
|
.prepare(
|
|
r#"
|
|
SELECT face_id,
|
|
cosine_similarity(?1, embedding)
|
|
FROM embeddings
|
|
"#,
|
|
)
|
|
.change_context(Error)
|
|
.unwrap();
|
|
|
|
let result_iter = stmt
|
|
.query_map(params![embedding_bytes], |row| {
|
|
Ok((row.get::<_, i64>(0)?, row.get::<_, f32>(1)?))
|
|
})
|
|
.change_context(Error)
|
|
.unwrap();
|
|
|
|
for result in result_iter {
|
|
println!("{:?}", result);
|
|
}
|
|
}
|
|
|
|
/// Get all embeddings for a specific model
|
|
pub fn get_all_embeddings(&self, model_name: Option<&str>) -> Result<Vec<EmbeddingRecord>> {
|
|
let mut embeddings = Vec::new();
|
|
|
|
if let Some(model) = model_name {
|
|
let mut stmt = self.conn.prepare(
|
|
"SELECT id, face_id, embedding, model_name, created_at FROM embeddings WHERE model_name = ?1"
|
|
).change_context(Error)?;
|
|
|
|
let embedding_iter = stmt
|
|
.query_map(params![model], |row| {
|
|
let embedding_bytes: Vec<u8> = row.get(2)?;
|
|
let embedding: ndarray::Array1<f32> = {
|
|
let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes)
|
|
.change_context(Error)
|
|
.unwrap();
|
|
sf.tensor::<f32, ndarray::Ix1>("embedding")
|
|
.unwrap()
|
|
.to_owned()
|
|
};
|
|
|
|
Ok(EmbeddingRecord {
|
|
id: row.get(0)?,
|
|
face_id: row.get(1)?,
|
|
embedding,
|
|
model_name: row.get(3)?,
|
|
created_at: row.get(4)?,
|
|
})
|
|
})
|
|
.change_context(Error)?;
|
|
|
|
for embedding in embedding_iter {
|
|
embeddings.push(embedding.change_context(Error)?);
|
|
}
|
|
} else {
|
|
let mut stmt = self
|
|
.conn
|
|
.prepare("SELECT id, face_id, embedding, model_name, created_at FROM embeddings")
|
|
.change_context(Error)?;
|
|
|
|
let embedding_iter = stmt
|
|
.query_map([], |row| {
|
|
let embedding_bytes: Vec<u8> = row.get(2)?;
|
|
let embedding: ndarray::Array1<f32> = {
|
|
let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes)
|
|
.change_context(Error)
|
|
.unwrap();
|
|
sf.tensor::<f32, ndarray::Ix1>("embedding")
|
|
.unwrap()
|
|
.to_owned()
|
|
};
|
|
|
|
Ok(EmbeddingRecord {
|
|
id: row.get(0)?,
|
|
face_id: row.get(1)?,
|
|
embedding,
|
|
model_name: row.get(3)?,
|
|
created_at: row.get(4)?,
|
|
})
|
|
})
|
|
.change_context(Error)?;
|
|
|
|
for embedding in embedding_iter {
|
|
embeddings.push(embedding.change_context(Error)?);
|
|
}
|
|
}
|
|
|
|
Ok(embeddings)
|
|
}
|
|
}
|
|
|
|
fn add_sqlite_cosine_similarity(db: &Connection) -> Result<()> {
|
|
use rusqlite::functions::*;
|
|
db.create_scalar_function(
|
|
"cosine_similarity",
|
|
2,
|
|
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
|
|
move |ctx| {
|
|
if ctx.len() != 2 {
|
|
return Err(rusqlite::Error::UserFunctionError(
|
|
"cosine_similarity requires exactly 2 arguments".into(),
|
|
));
|
|
}
|
|
let array_1 = ctx.get_raw(0).as_blob()?;
|
|
let array_2 = ctx.get_raw(1).as_blob()?;
|
|
|
|
let array_1_st = ndarray_safetensors::SafeArraysView::from_bytes(array_1)
|
|
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
|
let array_2_st = ndarray_safetensors::SafeArraysView::from_bytes(array_2)
|
|
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
|
|
|
let array_view_1 = array_1_st
|
|
.tensor_by_index::<f32, ndarray::Ix1>(0)
|
|
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
|
let array_view_2 = array_2_st
|
|
.tensor_by_index::<f32, ndarray::Ix1>(0)
|
|
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
|
|
|
let similarity = array_view_1
|
|
.cosine_similarity(array_view_2)
|
|
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
|
|
|
Ok(similarity)
|
|
},
|
|
)
|
|
.change_context(Error)
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created in-memory database has no rows in any table.
    #[test]
    fn test_database_creation() -> Result<()> {
        let db = FaceDatabase::in_memory()?;

        let stats = db.get_stats()?;
        assert_eq!(stats, (0, 0, 0, 0));

        Ok(())
    }

    /// An image stored by path and size can be read back by its id.
    #[test]
    fn test_store_and_retrieve_image() -> Result<()> {
        let db = FaceDatabase::in_memory()?;
        let image_id = db.store_image("/path/to/image.jpg", 800, 600)?;

        let stored = db.get_image(image_id)?.unwrap();
        assert_eq!(
            (stored.file_path.as_str(), stored.width, stored.height),
            ("/path/to/image.jpg", 800, 600)
        );

        Ok(())
    }
}
|