diff --git a/Cargo.toml b/Cargo.toml index 2177761..4e9ca64 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,7 +37,7 @@ nalgebra = { workspace = true } ndarray = "0.16.1" ndarray-image = { workspace = true } ndarray-resize = { workspace = true } -rusqlite = { version = "0.37.0", features = ["modern-full"] } +rusqlite = { version = "0.37.0", features = ["functions", "modern-full"] } tap = "1.0.1" thiserror = "2.0" tokio = "1.43.1" diff --git a/flake.nix b/flake.nix index 3d82597..1c2275f 100644 --- a/flake.nix +++ b/flake.nix @@ -114,16 +114,17 @@ stdenv = p: p.clangStdenv; doCheck = false; LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib"; - ORT_LIB_LOCATION = "${patchedOnnxruntime}"; - ORT_ENV_SYSTEM_LIB_LOCATION = "${patchedOnnxruntime}/lib"; - ORT_ENV_PREFER_DYNAMIC_LINK = true; + # ORT_LIB_LOCATION = "${patchedOnnxruntime}"; + # ORT_ENV_SYSTEM_LIB_LOCATION = "${patchedOnnxruntime}/lib"; + # ORT_ENV_PREFER_DYNAMIC_LINK = true; nativeBuildInputs = with pkgs; [ cmake pkg-config ]; buildInputs = with pkgs; [ - # onnxruntime + patchedOnnxruntime + sqlite ] ++ (lib.optionals pkgs.stdenv.isDarwin [ libiconv diff --git a/justfile b/justfile index 8cab077..4a7a44f 100644 --- a/justfile +++ b/justfile @@ -9,5 +9,5 @@ open: bench: cargo build --release BINARY="" hyperfine --warmup 3 --export-markdown benchmark.md \ - "$CARGO_TARGET_DIR/release/detector detect -f coreml selfie.jpg" \ - "$CARGO_TARGET_DIR/release/detector detect -f coreml -b 16 selfie.jpg" + "$CARGO_TARGET_DIR/release/detector detect -f cpu selfie.jpg" \ + "$CARGO_TARGET_DIR/release/detector detect -f cpu -b 1 selfie.jpg" diff --git a/ndarray-safetensors/src/lib.rs b/ndarray-safetensors/src/lib.rs index 395d553..5a17d64 100644 --- a/ndarray-safetensors/src/lib.rs +++ b/ndarray-safetensors/src/lib.rs @@ -68,17 +68,17 @@ use safetensors::tensor::SafeTensors; /// let view = SafeArrayView::from_bytes(&bytes).unwrap(); /// let tensor: ndarray::ArrayView2 = view.tensor("data").unwrap(); /// ``` -pub struct SafeArrayView<'a> { +pub struct SafeArraysView<'a> { pub tensors: SafeTensors<'a>, } -impl<'a> SafeArrayView<'a> { +impl<'a> SafeArraysView<'a> { fn new(tensors: SafeTensors<'a>) -> Self { Self { tensors } } /// Create a SafeArrayView from serialized bytes - pub fn from_bytes(bytes: &'a [u8]) -> Result> { + pub fn from_bytes(bytes: &'a [u8]) -> Result> { let tensors = SafeTensors::deserialize(bytes)?; Ok(Self::new(tensors)) } @@ -413,7 +413,7 @@ fn test_serialize_safe_arrays() { assert!(!serialized.is_empty()); // Deserialize to check if it works - let deserialized = SafeArrayView::from_bytes(&serialized).unwrap(); + let deserialized = SafeArraysView::from_bytes(&serialized).unwrap(); assert_eq!(deserialized.len(), 2); assert_eq!( deserialized diff --git a/src/database.rs b/src/database.rs index 43e356b..6338cce 100644 --- a/src/database.rs +++ b/src/database.rs @@ -2,6 +2,7 @@ use crate::errors::{Error, Result}; use crate::facedet::{FaceDetectionOutput, FaceLandmarks}; use bounding_box::Aabb2; use error_stack::ResultExt; +use ndarray_math::CosineSimilarity; use rusqlite::{Connection, OptionalExtension, params}; use std::path::Path; @@ -55,7 +56,7 @@ pub struct LandmarkRecord { pub struct EmbeddingRecord { pub id: i64, pub face_id: i64, - pub embedding: Vec, + pub embedding: ndarray::Array1, pub model_name: String, pub created_at: String, } @@ -64,6 +65,7 @@ impl FaceDatabase { /// Create a new database connection and initialize tables pub fn new>(db_path: P) -> Result { let conn = Connection::open(db_path).change_context(Error)?; + add_sqlite_cosine_similarity(&conn).change_context(Error)?; let db = Self { conn }; db.create_tables()?; Ok(db) @@ -431,10 +433,16 @@ impl FaceDatabase { let embedding_iter = stmt .query_map(params![face_id], |row| { let embedding_bytes: Vec = row.get(2)?; - let embedding: Vec = embedding_bytes - .chunks_exact(4) - .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) - .collect(); + let embedding: ndarray::Array1 = { + let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes) + .change_context(Error) + // .change_context(Error)? + .unwrap(); + sf.tensor::("embedding") + // .change_context(Error)? + .unwrap() + .to_owned() + }; Ok(EmbeddingRecord { id: row.get(0)?, @@ -454,46 +462,6 @@ impl FaceDatabase { Ok(embeddings) } - /// Find similar faces by embedding (using cosine similarity) - pub fn find_similar_faces( - &self, - query_embedding: &[f32], - threshold: f32, - limit: usize, - ) -> Result> { - let mut stmt = self - .conn - .prepare("SELECT face_id, embedding FROM embeddings") - .change_context(Error)?; - - let embedding_iter = stmt - .query_map([], |row| { - let face_id: i64 = row.get(0)?; - let embedding_bytes: Vec = row.get(1)?; - let embedding: Vec = embedding_bytes - .chunks_exact(4) - .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) - .collect(); - Ok((face_id, embedding)) - }) - .change_context(Error)?; - - let mut similarities = Vec::new(); - for result in embedding_iter { - let (face_id, embedding) = result.change_context(Error)?; - let similarity = cosine_similarity(query_embedding, &embedding); - if similarity >= threshold { - similarities.push((face_id, similarity)); - } - } - - // Sort by similarity (descending) and limit results - similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - similarities.truncate(limit); - - Ok(similarities) - } - /// Get database statistics pub fn get_stats(&self) -> Result<(usize, usize, usize, usize)> { let images: usize = self @@ -518,6 +486,85 @@ impl FaceDatabase { Ok((images, faces, landmarks, embeddings)) } + + /// Find similar faces based on cosine similarity of embeddings + /// Return ids and similarity scores of similar faces + pub fn find_similar_faces( + &self, + embedding: &ndarray::Array1, + threshold: f32, + limit: usize, + ) -> Result> { + // Serialize the query embedding to bytes + let embedding_bytes = + ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding.view())]) + .change_context(Error)? + .serialize() + .change_context(Error)?; + + let mut stmt = self + .conn + .prepare( + r#" SELECT face_id, cosine_similarity(?1, embedding) as similarity + FROM embeddings + WHERE cosine_similarity(?1, embedding) >= ?2 + ORDER BY similarity DESC + LIMIT ?3"#, + ) + .change_context(Error)?; + + let result = stmt + .query_map(params![embedding_bytes, threshold, limit], |row| { + Ok((row.get::<_, i64>(0)?, row.get::<_, f32>(1)?)) + }) + .change_context(Error)? + .map(|r| r.change_context(Error)) + .collect::>>()?; + + // let mut results = Vec::new(); + // for result in result_iter { + // results.push(result.change_context(Error)?); + // } + + Ok(result) + } +} + +fn add_sqlite_cosine_similarity(db: &Connection) -> Result<()> { + use rusqlite::functions::*; + db.create_scalar_function( + "cosine_similarity", + 2, + FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + move |ctx| { + if ctx.len() != 2 { + return Err(rusqlite::Error::UserFunctionError( + "cosine_similarity requires exactly 2 arguments".into(), + )); + } + let array_1 = ctx.get_raw(0).as_blob()?; + let array_2 = ctx.get_raw(1).as_blob()?; + + let array_1_st = ndarray_safetensors::SafeArraysView::from_bytes(array_1) + .map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?; + let array_2_st = ndarray_safetensors::SafeArraysView::from_bytes(array_2) + .map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?; + + let array_view_1 = array_1_st + .tensor::("embedding") + .map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?; + let array_view_2 = array_2_st + .tensor::("embedding") + .map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?; + + let similarity = array_view_1 + .cosine_similarity(array_view_2) + .map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?; + + Ok(similarity) + }, + ) + .change_context(Error) } #[cfg(test)] diff --git a/src/main.rs b/src/main.rs index 5d5e64e..c661f14 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,7 +14,7 @@ const RETINAFACE_MODEL_ONNX: &[u8] = include_bytes!("../models/retinaface.onnx") const FACENET_MODEL_ONNX: &[u8] = include_bytes!("../models/facenet.onnx"); pub fn main() -> Result<()> { tracing_subscriber::fmt() - .with_env_filter("error") + .with_env_filter("info") .with_thread_ids(true) .with_thread_names(true) .with_target(false) @@ -319,12 +319,9 @@ fn run_query(query: cli::Query) -> Result<()> { embedding.model_name, embedding.created_at ); - if query.show_embeddings { - println!( - " Values: {:?}", - &embedding.embedding[..embedding.embedding.len().min(10)] - ); - } + // if query.show_embeddings { + // println!(" Values: {:?}", &embedding.embedding); + // } } }