feat: save safetensors to the database
Some checks failed
build / checks-matrix (push) Successful in 19m23s
build / codecov (push) Failing after 19m26s
docs / docs (push) Failing after 28m47s
build / checks-build (push) Has been cancelled

This commit is contained in:
uttarayan21
2025-08-20 12:14:08 +05:30
parent 37adb74adf
commit 97f64e7e10
6 changed files with 108 additions and 63 deletions

View File

@@ -2,6 +2,7 @@ use crate::errors::{Error, Result};
use crate::facedet::{FaceDetectionOutput, FaceLandmarks};
use bounding_box::Aabb2;
use error_stack::ResultExt;
use ndarray_math::CosineSimilarity;
use rusqlite::{Connection, OptionalExtension, params};
use std::path::Path;
@@ -55,7 +56,7 @@ pub struct LandmarkRecord {
/// One embedding row together with the face it is attached to.
pub struct EmbeddingRecord {
/// Identifier of this embedding record.
pub id: i64,
/// Identifier of the associated face (presumably a key into the faces table — confirm against the schema).
pub face_id: i64,
// NOTE(review): two `embedding` declarations are visible here, a flat `Vec<f32>`
// and an `ndarray::Array1<f32>`. This looks like the before/after of a type
// change; a struct can only keep one of them — confirm which is intended.
pub embedding: Vec<f32>,
pub embedding: ndarray::Array1<f32>,
/// Name of the model associated with this embedding.
pub model_name: String,
/// Creation timestamp kept as text (exact format not visible here — TODO confirm).
pub created_at: String,
}
@@ -64,6 +65,7 @@ impl FaceDatabase {
/// Create a new database connection and initialize tables
pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Self> {
let conn = Connection::open(db_path).change_context(Error)?;
add_sqlite_cosine_similarity(&conn).change_context(Error)?;
let db = Self { conn };
db.create_tables()?;
Ok(db)
@@ -431,10 +433,16 @@ impl FaceDatabase {
let embedding_iter = stmt
.query_map(params![face_id], |row| {
let embedding_bytes: Vec<u8> = row.get(2)?;
let embedding: Vec<f32> = embedding_bytes
.chunks_exact(4)
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect();
let embedding: ndarray::Array1<f32> = {
let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes)
.change_context(Error)
// .change_context(Error)?
.unwrap();
sf.tensor::<f32, ndarray::Ix1>("embedding")
// .change_context(Error)?
.unwrap()
.to_owned()
};
Ok(EmbeddingRecord {
id: row.get(0)?,
@@ -454,46 +462,6 @@ impl FaceDatabase {
Ok(embeddings)
}
/// Find similar faces by embedding (using cosine similarity)
pub fn find_similar_faces(
&self,
query_embedding: &[f32],
threshold: f32,
limit: usize,
) -> Result<Vec<(i64, f32)>> {
let mut stmt = self
.conn
.prepare("SELECT face_id, embedding FROM embeddings")
.change_context(Error)?;
let embedding_iter = stmt
.query_map([], |row| {
let face_id: i64 = row.get(0)?;
let embedding_bytes: Vec<u8> = row.get(1)?;
let embedding: Vec<f32> = embedding_bytes
.chunks_exact(4)
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect();
Ok((face_id, embedding))
})
.change_context(Error)?;
let mut similarities = Vec::new();
for result in embedding_iter {
let (face_id, embedding) = result.change_context(Error)?;
let similarity = cosine_similarity(query_embedding, &embedding);
if similarity >= threshold {
similarities.push((face_id, similarity));
}
}
// Sort by similarity (descending) and limit results
similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
similarities.truncate(limit);
Ok(similarities)
}
/// Get database statistics
pub fn get_stats(&self) -> Result<(usize, usize, usize, usize)> {
let images: usize = self
@@ -518,6 +486,85 @@ impl FaceDatabase {
Ok((images, faces, landmarks, embeddings))
}
/// Find similar faces based on cosine similarity of embeddings
/// Return ids and similarity scores of similar faces
pub fn find_similar_faces(
&self,
embedding: &ndarray::Array1<f32>,
threshold: f32,
limit: usize,
) -> Result<Vec<(i64, f32)>> {
// Serialize the query embedding to bytes
let embedding_bytes =
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding.view())])
.change_context(Error)?
.serialize()
.change_context(Error)?;
let mut stmt = self
.conn
.prepare(
r#" SELECT face_id, cosine_similarity(?1, embedding) as similarity
FROM embeddings
WHERE cosine_similarity(?1, embedding) >= ?2
ORDER BY similarity DESC
LIMIT ?3"#,
)
.change_context(Error)?;
let result = stmt
.query_map(params![embedding_bytes, threshold, limit], |row| {
Ok((row.get::<_, i64>(0)?, row.get::<_, f32>(1)?))
})
.change_context(Error)?
.map(|r| r.change_context(Error))
.collect::<Result<Vec<_>>>()?;
// let mut results = Vec::new();
// for result in result_iter {
// results.push(result.change_context(Error)?);
// }
Ok(result)
}
}
fn add_sqlite_cosine_similarity(db: &Connection) -> Result<()> {
use rusqlite::functions::*;
db.create_scalar_function(
"cosine_similarity",
2,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
move |ctx| {
if ctx.len() != 2 {
return Err(rusqlite::Error::UserFunctionError(
"cosine_similarity requires exactly 2 arguments".into(),
));
}
let array_1 = ctx.get_raw(0).as_blob()?;
let array_2 = ctx.get_raw(1).as_blob()?;
let array_1_st = ndarray_safetensors::SafeArraysView::from_bytes(array_1)
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
let array_2_st = ndarray_safetensors::SafeArraysView::from_bytes(array_2)
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
let array_view_1 = array_1_st
.tensor::<f32, ndarray::Ix1>("embedding")
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
let array_view_2 = array_2_st
.tensor::<f32, ndarray::Ix1>("embedding")
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
let similarity = array_view_1
.cosine_similarity(array_view_2)
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
Ok(similarity)
},
)
.change_context(Error)
}
#[cfg(test)]

View File

@@ -14,7 +14,7 @@ const RETINAFACE_MODEL_ONNX: &[u8] = include_bytes!("../models/retinaface.onnx")
const FACENET_MODEL_ONNX: &[u8] = include_bytes!("../models/facenet.onnx");
pub fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_env_filter("error")
.with_env_filter("info")
.with_thread_ids(true)
.with_thread_names(true)
.with_target(false)
@@ -319,12 +319,9 @@ fn run_query(query: cli::Query) -> Result<()> {
embedding.model_name,
embedding.created_at
);
if query.show_embeddings {
println!(
" Values: {:?}",
&embedding.embedding[..embedding.embedding.len().min(10)]
);
}
// if query.show_embeddings {
// println!(" Values: {:?}", &embedding.embedding);
// }
}
}