feat: save safetensors to the database
This commit is contained in:
137
src/database.rs
137
src/database.rs
@@ -2,6 +2,7 @@ use crate::errors::{Error, Result};
|
||||
use crate::facedet::{FaceDetectionOutput, FaceLandmarks};
|
||||
use bounding_box::Aabb2;
|
||||
use error_stack::ResultExt;
|
||||
use ndarray_math::CosineSimilarity;
|
||||
use rusqlite::{Connection, OptionalExtension, params};
|
||||
use std::path::Path;
|
||||
|
||||
@@ -55,7 +56,7 @@ pub struct LandmarkRecord {
|
||||
pub struct EmbeddingRecord {
|
||||
pub id: i64,
|
||||
pub face_id: i64,
|
||||
pub embedding: Vec<f32>,
|
||||
pub embedding: ndarray::Array1<f32>,
|
||||
pub model_name: String,
|
||||
pub created_at: String,
|
||||
}
|
||||
@@ -64,6 +65,7 @@ impl FaceDatabase {
|
||||
/// Create a new database connection and initialize tables
|
||||
pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Self> {
|
||||
let conn = Connection::open(db_path).change_context(Error)?;
|
||||
add_sqlite_cosine_similarity(&conn).change_context(Error)?;
|
||||
let db = Self { conn };
|
||||
db.create_tables()?;
|
||||
Ok(db)
|
||||
@@ -431,10 +433,16 @@ impl FaceDatabase {
|
||||
let embedding_iter = stmt
|
||||
.query_map(params![face_id], |row| {
|
||||
let embedding_bytes: Vec<u8> = row.get(2)?;
|
||||
let embedding: Vec<f32> = embedding_bytes
|
||||
.chunks_exact(4)
|
||||
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
|
||||
.collect();
|
||||
let embedding: ndarray::Array1<f32> = {
|
||||
let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes)
|
||||
.change_context(Error)
|
||||
// .change_context(Error)?
|
||||
.unwrap();
|
||||
sf.tensor::<f32, ndarray::Ix1>("embedding")
|
||||
// .change_context(Error)?
|
||||
.unwrap()
|
||||
.to_owned()
|
||||
};
|
||||
|
||||
Ok(EmbeddingRecord {
|
||||
id: row.get(0)?,
|
||||
@@ -454,46 +462,6 @@ impl FaceDatabase {
|
||||
Ok(embeddings)
|
||||
}
|
||||
|
||||
/// Find similar faces by embedding (using cosine similarity)
|
||||
pub fn find_similar_faces(
|
||||
&self,
|
||||
query_embedding: &[f32],
|
||||
threshold: f32,
|
||||
limit: usize,
|
||||
) -> Result<Vec<(i64, f32)>> {
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare("SELECT face_id, embedding FROM embeddings")
|
||||
.change_context(Error)?;
|
||||
|
||||
let embedding_iter = stmt
|
||||
.query_map([], |row| {
|
||||
let face_id: i64 = row.get(0)?;
|
||||
let embedding_bytes: Vec<u8> = row.get(1)?;
|
||||
let embedding: Vec<f32> = embedding_bytes
|
||||
.chunks_exact(4)
|
||||
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
|
||||
.collect();
|
||||
Ok((face_id, embedding))
|
||||
})
|
||||
.change_context(Error)?;
|
||||
|
||||
let mut similarities = Vec::new();
|
||||
for result in embedding_iter {
|
||||
let (face_id, embedding) = result.change_context(Error)?;
|
||||
let similarity = cosine_similarity(query_embedding, &embedding);
|
||||
if similarity >= threshold {
|
||||
similarities.push((face_id, similarity));
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by similarity (descending) and limit results
|
||||
similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
similarities.truncate(limit);
|
||||
|
||||
Ok(similarities)
|
||||
}
|
||||
|
||||
/// Get database statistics
|
||||
pub fn get_stats(&self) -> Result<(usize, usize, usize, usize)> {
|
||||
let images: usize = self
|
||||
@@ -518,6 +486,85 @@ impl FaceDatabase {
|
||||
|
||||
Ok((images, faces, landmarks, embeddings))
|
||||
}
|
||||
|
||||
/// Find similar faces based on cosine similarity of embeddings
|
||||
/// Return ids and similarity scores of similar faces
|
||||
pub fn find_similar_faces(
|
||||
&self,
|
||||
embedding: &ndarray::Array1<f32>,
|
||||
threshold: f32,
|
||||
limit: usize,
|
||||
) -> Result<Vec<(i64, f32)>> {
|
||||
// Serialize the query embedding to bytes
|
||||
let embedding_bytes =
|
||||
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding.view())])
|
||||
.change_context(Error)?
|
||||
.serialize()
|
||||
.change_context(Error)?;
|
||||
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare(
|
||||
r#" SELECT face_id, cosine_similarity(?1, embedding) as similarity
|
||||
FROM embeddings
|
||||
WHERE cosine_similarity(?1, embedding) >= ?2
|
||||
ORDER BY similarity DESC
|
||||
LIMIT ?3"#,
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
let result = stmt
|
||||
.query_map(params![embedding_bytes, threshold, limit], |row| {
|
||||
Ok((row.get::<_, i64>(0)?, row.get::<_, f32>(1)?))
|
||||
})
|
||||
.change_context(Error)?
|
||||
.map(|r| r.change_context(Error))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
// let mut results = Vec::new();
|
||||
// for result in result_iter {
|
||||
// results.push(result.change_context(Error)?);
|
||||
// }
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
fn add_sqlite_cosine_similarity(db: &Connection) -> Result<()> {
|
||||
use rusqlite::functions::*;
|
||||
db.create_scalar_function(
|
||||
"cosine_similarity",
|
||||
2,
|
||||
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
|
||||
move |ctx| {
|
||||
if ctx.len() != 2 {
|
||||
return Err(rusqlite::Error::UserFunctionError(
|
||||
"cosine_similarity requires exactly 2 arguments".into(),
|
||||
));
|
||||
}
|
||||
let array_1 = ctx.get_raw(0).as_blob()?;
|
||||
let array_2 = ctx.get_raw(1).as_blob()?;
|
||||
|
||||
let array_1_st = ndarray_safetensors::SafeArraysView::from_bytes(array_1)
|
||||
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||
let array_2_st = ndarray_safetensors::SafeArraysView::from_bytes(array_2)
|
||||
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||
|
||||
let array_view_1 = array_1_st
|
||||
.tensor::<f32, ndarray::Ix1>("embedding")
|
||||
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||
let array_view_2 = array_2_st
|
||||
.tensor::<f32, ndarray::Ix1>("embedding")
|
||||
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||
|
||||
let similarity = array_view_1
|
||||
.cosine_similarity(array_view_2)
|
||||
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||
|
||||
Ok(similarity)
|
||||
},
|
||||
)
|
||||
.change_context(Error)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
11
src/main.rs
11
src/main.rs
@@ -14,7 +14,7 @@ const RETINAFACE_MODEL_ONNX: &[u8] = include_bytes!("../models/retinaface.onnx")
|
||||
const FACENET_MODEL_ONNX: &[u8] = include_bytes!("../models/facenet.onnx");
|
||||
pub fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter("error")
|
||||
.with_env_filter("info")
|
||||
.with_thread_ids(true)
|
||||
.with_thread_names(true)
|
||||
.with_target(false)
|
||||
@@ -319,12 +319,9 @@ fn run_query(query: cli::Query) -> Result<()> {
|
||||
embedding.model_name,
|
||||
embedding.created_at
|
||||
);
|
||||
if query.show_embeddings {
|
||||
println!(
|
||||
" Values: {:?}",
|
||||
&embedding.embedding[..embedding.embedding.len().min(10)]
|
||||
);
|
||||
}
|
||||
// if query.show_embeddings {
|
||||
// println!(" Values: {:?}", &embedding.embedding);
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user