From 47218fa69639d141eeeb258531bf549f09008405 Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Tue, 19 Aug 2025 18:33:38 +0530 Subject: [PATCH] feat: Added ndarray-safetensors --- Cargo.lock | 48 +++ Cargo.toml | 10 +- ndarray-safetensors/Cargo.toml | 11 + ndarray-safetensors/src/lib.rs | 432 ++++++++++++++++++++++++++ src/cli.rs | 44 +++ src/database.rs | 548 +++++++++++++++++++++++++++++++++ src/lib.rs | 1 + src/main.rs | 198 +++++++++++- 8 files changed, 1285 insertions(+), 7 deletions(-) create mode 100644 ndarray-safetensors/Cargo.toml create mode 100644 ndarray-safetensors/src/lib.rs create mode 100644 src/database.rs diff --git a/Cargo.lock b/Cargo.lock index ac39c0a..540a5e9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -269,6 +269,20 @@ name = "bytemuck" version = "1.23.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3995eaeebcdf32f91f980d360f78732ddc061097ab4e39991ae7a6ace9194677" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f154e572231cb6ba2bd1176980827e3d5dc04cc183a75dea38109fbdd672d29" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] [[package]] name = "byteorder-lite" @@ -504,7 +518,9 @@ dependencies = [ "nalgebra", "ndarray", "ndarray-image", + "ndarray-math", "ndarray-resize", + "ndarray-safetensors", "ordered-float", "ort", "rusqlite", @@ -830,6 +846,7 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" dependencies = [ + "bytemuck", "cfg-if", "crunchy", ] @@ -1414,6 +1431,16 @@ dependencies = [ "ndarray", ] +[[package]] +name = "ndarray-math" +version = "0.1.0" +source = "git+https://git.darksailor.dev/servius/ndarray-math#f047966f20835267f20e5839272b9ab36c445796" +dependencies = [ + "ndarray", + "num", + "thiserror 2.0.15", +] + [[package]] name = 
"ndarray-resize" version = "0.1.0" @@ -1426,6 +1453,17 @@ dependencies = [ "thiserror 2.0.15", ] +[[package]] +name = "ndarray-safetensors" +version = "0.1.0" +dependencies = [ + "bytemuck", + "half", + "ndarray", + "safetensors", + "thiserror 2.0.15", +] + [[package]] name = "new_debug_unreachable" version = "1.0.6" @@ -1983,6 +2021,16 @@ dependencies = [ "bytemuck", ] +[[package]] +name = "safetensors" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "172dd94c5a87b5c79f945c863da53b2ebc7ccef4eca24ac63cca66a41aab2178" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "scopeguard" version = "1.2.0" diff --git a/Cargo.toml b/Cargo.toml index 83cbee7..2177761 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["ndarray-image", "ndarray-resize", ".", "bounding-box"] +members = ["ndarray-image", "ndarray-resize", ".", "bounding-box", "ndarray-safetensors"] [workspace.package] version = "0.1.0" @@ -50,11 +50,9 @@ bounding-box = { version = "0.1.0", path = "bounding-box" } color = "0.3.1" itertools = "0.14.0" ordered-float = "5.0.0" -ort = { version = "2.0.0-rc.10", default-features = false, features = [ - "std", - "tracing", - "ndarray", -] } +ort = { version = "2.0.0-rc.10", default-features = false, features = [ "std", "tracing", "ndarray"]} +ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" } +ndarray-safetensors = { version = "0.1.0", path = "ndarray-safetensors" } [profile.release] debug = true diff --git a/ndarray-safetensors/Cargo.toml b/ndarray-safetensors/Cargo.toml new file mode 100644 index 0000000..30f752f --- /dev/null +++ b/ndarray-safetensors/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "ndarray-safetensors" +version.workspace = true +edition.workspace = true + +[dependencies] +bytemuck = { version = "1.23.2" } +half = { version = "2.6.0", default-features = false, features = ["bytemuck"] } +ndarray = { version = 
"0.16.1", default-features = false, features = ["std"] } +safetensors = "0.6.2" +thiserror = "2.0.15" diff --git a/ndarray-safetensors/src/lib.rs b/ndarray-safetensors/src/lib.rs new file mode 100644 index 0000000..395d553 --- /dev/null +++ b/ndarray-safetensors/src/lib.rs @@ -0,0 +1,432 @@ +//! # ndarray-safetensors +//! +//! A Rust library for serializing and deserializing `ndarray` arrays using the SafeTensors format. +//! +//! ## Features +//! - Serialize `ndarray::ArrayView` to SafeTensors format +//! - Deserialize SafeTensors data back to `ndarray::ArrayView` +//! - Support for multiple data types (f32, f64, i8-i64, u8-u64, f16, bf16) +//! - Zero-copy deserialization when possible +//! - Metadata support +//! +//! ## Example +//! ```rust +//! use ndarray::Array2; +//! use ndarray_safetensors::{SafeArrays, SafeArrayView}; +//! +//! // Create some data +//! let array = Array2::<f32>::zeros((3, 4)); +//! +//! // Serialize +//! let mut safe_arrays = SafeArrays::new(); +//! safe_arrays.insert_ndarray("my_tensor", array.view()).unwrap(); +//! safe_arrays.insert_metadata("author", "example"); +//! let bytes = safe_arrays.serialize().unwrap(); +//! +//! // Deserialize +//! let view = SafeArrayView::from_bytes(&bytes).unwrap(); +//! let tensor: ndarray::ArrayView2<f32> = view.tensor("my_tensor").unwrap(); +//! assert_eq!(tensor.shape(), &[3, 4]); +//! 
``` + +use safetensors::View; +use std::borrow::Cow; +use std::collections::{BTreeMap, HashMap}; + +use thiserror::Error; +/// Errors that can occur during SafeTensor operations +#[derive(Error, Debug)] +pub enum SafeTensorError { + #[error("Tensor not found: {0}")] + TensorNotFound(String), + #[error("Invalid tensor data: Got {0} Expected: {1}")] + InvalidTensorData(&'static str, String), + #[error("IO error: {0}")] + IoError(#[from] std::io::Error), + #[error("Safetensor error: {0}")] + SafeTensor(#[from] safetensors::SafeTensorError), + #[error("ndarray::ShapeError error: {0}")] + NdarrayShapeError(#[from] ndarray::ShapeError), +} + +type Result = core::result::Result; + +use safetensors::tensor::SafeTensors; + +/// A view into SafeTensors data that provides access to ndarray tensors +/// +/// # Example +/// ```rust +/// use ndarray::Array2; +/// use ndarray_safetensors::{SafeArrays, SafeArrayView}; +/// +/// let array = Array2::::ones((2, 3)); +/// let mut safe_arrays = SafeArrays::new(); +/// safe_arrays.insert_ndarray("data", array.view()).unwrap(); +/// let bytes = safe_arrays.serialize().unwrap(); +/// +/// let view = SafeArrayView::from_bytes(&bytes).unwrap(); +/// let tensor: ndarray::ArrayView2 = view.tensor("data").unwrap(); +/// ``` +pub struct SafeArrayView<'a> { + pub tensors: SafeTensors<'a>, +} + +impl<'a> SafeArrayView<'a> { + fn new(tensors: SafeTensors<'a>) -> Self { + Self { tensors } + } + + /// Create a SafeArrayView from serialized bytes + pub fn from_bytes(bytes: &'a [u8]) -> Result> { + let tensors = SafeTensors::deserialize(bytes)?; + Ok(Self::new(tensors)) + } + + /// Get a dynamic-dimensional tensor by name + pub fn dynamic_tensor(&self, name: &str) -> Result> { + self.tensors + .tensor(name) + .map(|tensor| tensor_view_to_array_view(tensor))? 
+ } + + /// Get a tensor with specific dimensions by name + /// + /// # Example + /// ```rust + /// # use ndarray::Array2; + /// # use ndarray_safetensors::{SafeArrays, SafeArrayView}; + /// # let array = Array2::::ones((2, 3)); + /// # let mut safe_arrays = SafeArrays::new(); + /// # safe_arrays.insert_ndarray("data", array.view()).unwrap(); + /// # let bytes = safe_arrays.serialize().unwrap(); + /// # let view = SafeArrayView::from_bytes(&bytes).unwrap(); + /// let tensor: ndarray::ArrayView2 = view.tensor("data").unwrap(); + /// ``` + pub fn tensor( + &self, + name: &str, + ) -> Result> { + Ok(self + .tensors + .tensor(name) + .map(|tensor| tensor_view_to_array_view(tensor))? + .map(|array_view| array_view.into_dimensionality::())??) + } + + /// Get an iterator over tensor names + pub fn names(&self) -> std::vec::IntoIter<&str> { + self.tensors.names().into_iter() + } + + /// Get the number of tensors + pub fn len(&self) -> usize { + self.tensors.len() + } + + /// Check if there are no tensors + pub fn is_empty(&self) -> bool { + self.tensors.is_empty() + } +} + +/// Trait for types that can be stored in SafeTensors +/// +/// Implemented for: f32, f64, i8, i16, i32, i64, u8, u16, u32, u64, f16, bf16 +pub trait STDtype: bytemuck::Pod { + fn dtype() -> safetensors::tensor::Dtype; + fn size() -> usize { + (Self::dtype().bitsize() / 8).max(1) + } +} + +macro_rules! impl_dtype { + ($($t:ty => $dtype:expr),* $(,)?) 
=> { + $( + impl STDtype for $t { + fn dtype() -> safetensors::tensor::Dtype { + $dtype + } + } + )* + }; +} + +use safetensors::tensor::Dtype; + +impl_dtype!( + // bool => Dtype::BOOL, // idk if ndarray::ArrayD is packed + f32 => Dtype::F32, + f64 => Dtype::F64, + i8 => Dtype::I8, + i16 => Dtype::I16, + i32 => Dtype::I32, + i64 => Dtype::I64, + u8 => Dtype::U8, + u16 => Dtype::U16, + u32 => Dtype::U32, + u64 => Dtype::U64, + half::f16 => Dtype::F16, + half::bf16 => Dtype::BF16, +); + +fn tensor_view_to_array_view<'a, T: STDtype>( + tensor: safetensors::tensor::TensorView<'a>, +) -> Result> { + let shape = tensor.shape(); + let dtype = tensor.dtype(); + if T::dtype() != dtype { + return Err(SafeTensorError::InvalidTensorData( + core::any::type_name::(), + dtype.to_string(), + )); + } + + let data = tensor.data(); + let data: &[T] = bytemuck::cast_slice(data); + let array = ndarray::ArrayViewD::from_shape(shape, data)?; + Ok(array) +} + +/// Builder for creating SafeTensors data from ndarray tensors +/// +/// # Example +/// ```rust +/// use ndarray::{Array1, Array2}; +/// use ndarray_safetensors::SafeArrays; +/// +/// let mut safe_arrays = SafeArrays::new(); +/// +/// let array1 = Array1::::from_vec(vec![1.0, 2.0, 3.0]); +/// let array2 = Array2::::zeros((2, 2)); +/// +/// safe_arrays.insert_ndarray("vector", array1.view()).unwrap(); +/// safe_arrays.insert_ndarray("matrix", array2.view()).unwrap(); +/// safe_arrays.insert_metadata("version", "1.0"); +/// +/// let bytes = safe_arrays.serialize().unwrap(); +/// ``` +#[derive(Debug, Clone, Default)] +#[non_exhaustive] +pub struct SafeArrays<'a> { + pub tensors: BTreeMap>, + pub metadata: Option>, +} + +impl<'a, K: AsRef> FromIterator<(K, SafeArray<'a>)> for SafeArrays<'a> { + fn from_iter)>>(iter: T) -> Self { + let tensors = iter + .into_iter() + .map(|(k, v)| (k.as_ref().to_owned(), v)) + .collect(); + Self { + tensors, + metadata: None, + } + } +} + +impl<'a, K: AsRef, T: IntoIterator)>> From for SafeArrays<'a> { + 
fn from(iter: T) -> Self { + let tensors = iter + .into_iter() + .map(|(k, v)| (k.as_ref().to_owned(), v)) + .collect(); + Self { + tensors, + metadata: None, + } + } +} + +impl<'a> SafeArrays<'a> { + /// Create a SafeArrays from an iterator of (name, ndarray::ArrayView) pairs + /// ```rust + /// use ndarray::{Array2, Array3}; + /// use ndarray_safetensors::{SafeArrays, SafeArray}; + /// let array = Array2::::zeros((3, 4)); + /// let safe_arrays = SafeArrays::from_ndarrays(vec![ + /// ("test_tensor", array.view()), + /// ("test_tensor2", array.view()), + /// ]).unwrap(); + /// ``` + + pub fn from_ndarrays< + K: AsRef, + T: STDtype, + D: ndarray::Dimension + 'a, + I: IntoIterator)>, + >( + iter: I, + ) -> Result { + let tensors = iter + .into_iter() + .map(|(k, v)| Ok((k.as_ref().to_owned(), SafeArray::from_ndarray(v)?))) + .collect::>>>()?; + Ok(Self { + tensors, + metadata: None, + }) + } +} + +// impl<'a, K: AsRef, T: IntoIterator)>> From for SafeArrays<'a> { +// fn from(iter: T) -> Self { +// let tensors = iter +// .into_iter() +// .map(|(k, v)| (k.as_ref().to_owned(), v)) +// .collect(); +// Self { +// tensors, +// metadata: None, +// } +// } +// } + +impl<'a> SafeArrays<'a> { + /// Create a new empty SafeArrays builder + pub const fn new() -> Self { + Self { + tensors: BTreeMap::new(), + metadata: None, + } + } + + /// Insert a SafeArray tensor with the given name + pub fn insert_tensor<'b: 'a>(&mut self, name: impl AsRef, tensor: SafeArray<'b>) { + self.tensors.insert(name.as_ref().to_owned(), tensor); + } + + /// Insert an ndarray tensor with the given name + /// + /// The array must be in standard layout and contiguous. 
+ pub fn insert_ndarray<'b: 'a, T: STDtype, D: ndarray::Dimension + 'a>( + &mut self, + name: impl AsRef, + array: ndarray::ArrayView<'b, T, D>, + ) -> Result<()> { + self.insert_tensor(name, SafeArray::from_ndarray(array)?); + Ok(()) + } + + /// Insert metadata key-value pair + pub fn insert_metadata(&mut self, key: impl AsRef, value: impl AsRef) { + self.metadata + .get_or_insert_default() + .insert(key.as_ref().to_owned(), value.as_ref().to_owned()); + } + + /// Serialize all tensors and metadata to bytes + pub fn serialize(self) -> Result> { + let out = safetensors::serialize(self.tensors, self.metadata) + .map_err(SafeTensorError::SafeTensor)?; + Ok(out) + } +} + +/// A tensor that can be serialized to SafeTensors format +#[derive(Debug, Clone)] +pub struct SafeArray<'a> { + data: Cow<'a, [u8]>, + shape: Vec, + dtype: safetensors::tensor::Dtype, +} + +impl View for SafeArray<'_> { + fn dtype(&self) -> safetensors::tensor::Dtype { + self.dtype + } + + fn shape(&self) -> &[usize] { + &self.shape + } + + fn data(&self) -> Cow<'_, [u8]> { + self.data.clone() + } + + fn data_len(&self) -> usize { + self.data.len() + } +} + +impl<'a> SafeArray<'a> { + fn from_ndarray<'b: 'a, T: STDtype, D: ndarray::Dimension + 'a>( + array: ndarray::ArrayView<'b, T, D>, + ) -> Result { + let shape = array.shape().to_vec(); + let dtype = T::dtype(); + if array.ndim() == 0 { + return Err(SafeTensorError::InvalidTensorData( + core::any::type_name::(), + "Cannot insert a scalar tensor".to_string(), + )); + } + + if !array.is_standard_layout() { + return Err(SafeTensorError::InvalidTensorData( + core::any::type_name::(), + "ArrayView is not standard layout".to_string(), + )); + } + let data = + bytemuck::cast_slice(array.to_slice().ok_or(SafeTensorError::InvalidTensorData( + core::any::type_name::(), + "ArrayView is not contiguous".to_string(), + ))?); + let safe_array = SafeArray { + data: Cow::Borrowed(data), + shape, + dtype, + }; + Ok(safe_array) + } +} + +#[test] +fn 
test_safe_array_from_ndarray() { + use ndarray::Array2; + + let array = Array2::::zeros((3, 4)); + let safe_array = SafeArray::from_ndarray(array.view()).unwrap(); + assert_eq!(safe_array.shape, vec![3, 4]); + assert_eq!(safe_array.dtype, safetensors::tensor::Dtype::F32); + assert_eq!(safe_array.data.len(), 3 * 4 * 4); // 3x4x4 bytes for f32 +} + +#[test] +fn test_serialize_safe_arrays() { + use ndarray::{Array2, Array3}; + + let mut safe_arrays = SafeArrays::new(); + let array = Array2::::zeros((3, 4)); + let array2 = Array3::::zeros((8, 1, 9)); + safe_arrays + .insert_ndarray("test_tensor", array.view()) + .unwrap(); + safe_arrays + .insert_ndarray("test_tensor2", array2.view()) + .unwrap(); + safe_arrays.insert_metadata("author", "example"); + + let serialized = safe_arrays.serialize().unwrap(); + assert!(!serialized.is_empty()); + + // Deserialize to check if it works + let deserialized = SafeArrayView::from_bytes(&serialized).unwrap(); + assert_eq!(deserialized.len(), 2); + assert_eq!( + deserialized + .tensor::("test_tensor") + .unwrap() + .shape(), + &[3, 4] + ); + assert_eq!( + deserialized + .tensor::("test_tensor2") + .unwrap() + .shape(), + &[8, 1, 9] + ); +} diff --git a/src/cli.rs b/src/cli.rs index b198bce..1fb20d4 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -13,6 +13,12 @@ pub enum SubCommand { Detect(Detect), #[clap(name = "list")] List(List), + #[clap(name = "query")] + Query(Query), + #[clap(name = "similar")] + Similar(Similar), + #[clap(name = "stats")] + Stats(Stats), #[clap(name = "completions")] Completions { shell: clap_complete::Shell }, } @@ -58,12 +64,50 @@ pub struct Detect { pub nms_threshold: f32, #[clap(short, long, default_value_t = 8)] pub batch_size: usize, + #[clap(short = 'd', long)] + pub database: Option, + #[clap(long, default_value = "facenet")] + pub model_name: String, + #[clap(long)] + pub save_to_db: bool, pub image: PathBuf, } #[derive(Debug, clap::Args)] pub struct List {} +#[derive(Debug, clap::Args)] +pub struct Query 
{ + #[clap(short = 'd', long, default_value = "face_detections.db")] + pub database: PathBuf, + #[clap(short, long)] + pub image_id: Option, + #[clap(short, long)] + pub face_id: Option, + #[clap(long)] + pub show_embeddings: bool, + #[clap(long)] + pub show_landmarks: bool, +} + +#[derive(Debug, clap::Args)] +pub struct Similar { + #[clap(short = 'd', long, default_value = "face_detections.db")] + pub database: PathBuf, + #[clap(short, long)] + pub face_id: i64, + #[clap(short, long, default_value_t = 0.7)] + pub threshold: f32, + #[clap(short, long, default_value_t = 10)] + pub limit: usize, +} + +#[derive(Debug, clap::Args)] +pub struct Stats { + #[clap(short = 'd', long, default_value = "face_detections.db")] + pub database: PathBuf, +} + impl Cli { pub fn completions(shell: clap_complete::Shell) { let mut command = ::command(); diff --git a/src/database.rs b/src/database.rs new file mode 100644 index 0000000..2fef122 --- /dev/null +++ b/src/database.rs @@ -0,0 +1,548 @@ +use crate::errors::{Error, Result}; +use crate::facedet::{FaceDetectionOutput, FaceLandmarks}; +use bounding_box::Aabb2; +use error_stack::ResultExt; +use rusqlite::{Connection, OptionalExtension, params}; +use std::path::Path; + +/// Database connection and operations for face detection results +pub struct FaceDatabase { + conn: Connection, +} + +/// Represents a stored image record +#[derive(Debug, Clone)] +pub struct ImageRecord { + pub id: i64, + pub file_path: String, + pub width: u32, + pub height: u32, + pub created_at: String, +} + +/// Represents a stored face detection record +#[derive(Debug, Clone)] +pub struct FaceRecord { + pub id: i64, + pub image_id: i64, + pub bbox_x1: f32, + pub bbox_y1: f32, + pub bbox_x2: f32, + pub bbox_y2: f32, + pub confidence: f32, + pub created_at: String, +} + +/// Represents stored face landmarks +#[derive(Debug, Clone)] +pub struct LandmarkRecord { + pub id: i64, + pub face_id: i64, + pub left_eye_x: f32, + pub left_eye_y: f32, + pub right_eye_x: 
f32, + pub right_eye_y: f32, + pub nose_x: f32, + pub nose_y: f32, + pub left_mouth_x: f32, + pub left_mouth_y: f32, + pub right_mouth_x: f32, + pub right_mouth_y: f32, +} + +/// Represents a stored face embedding +#[derive(Debug, Clone)] +pub struct EmbeddingRecord { + pub id: i64, + pub face_id: i64, + pub embedding: Vec, + pub model_name: String, + pub created_at: String, +} + +impl FaceDatabase { + /// Create a new database connection and initialize tables + pub fn new>(db_path: P) -> Result { + let conn = Connection::open(db_path).change_context(Error)?; + let db = Self { conn }; + db.create_tables()?; + Ok(db) + } + + /// Create an in-memory database for testing + pub fn in_memory() -> Result { + let conn = Connection::open_in_memory().change_context(Error)?; + let db = Self { conn }; + db.create_tables()?; + Ok(db) + } + + /// Create all necessary database tables + fn create_tables(&self) -> Result<()> { + // Images table + self.conn + .execute( + r#" + CREATE TABLE IF NOT EXISTS images ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + file_path TEXT NOT NULL UNIQUE, + width INTEGER NOT NULL, + height INTEGER NOT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP + ) + "#, + [], + ) + .change_context(Error)?; + + // Faces table + self.conn + .execute( + r#" + CREATE TABLE IF NOT EXISTS faces ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + image_id INTEGER NOT NULL, + bbox_x1 REAL NOT NULL, + bbox_y1 REAL NOT NULL, + bbox_x2 REAL NOT NULL, + bbox_y2 REAL NOT NULL, + confidence REAL NOT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE + ) + "#, + [], + ) + .change_context(Error)?; + + // Landmarks table + self.conn + .execute( + r#" + CREATE TABLE IF NOT EXISTS landmarks ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + face_id INTEGER NOT NULL, + left_eye_x REAL NOT NULL, + left_eye_y REAL NOT NULL, + right_eye_x REAL NOT NULL, + right_eye_y REAL NOT NULL, + nose_x REAL NOT NULL, + nose_y REAL NOT 
NULL, + left_mouth_x REAL NOT NULL, + left_mouth_y REAL NOT NULL, + right_mouth_x REAL NOT NULL, + right_mouth_y REAL NOT NULL, + FOREIGN KEY (face_id) REFERENCES faces (id) ON DELETE CASCADE + ) + "#, + [], + ) + .change_context(Error)?; + + // Embeddings table + self.conn + .execute( + r#" + CREATE TABLE IF NOT EXISTS embeddings ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + face_id INTEGER NOT NULL, + embedding BLOB NOT NULL, + model_name TEXT NOT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (face_id) REFERENCES faces (id) ON DELETE CASCADE + ) + "#, + [], + ) + .change_context(Error)?; + + // Create indexes for better performance + self.conn + .execute( + "CREATE INDEX IF NOT EXISTS idx_faces_image_id ON faces (image_id)", + [], + ) + .change_context(Error)?; + + self.conn + .execute( + "CREATE INDEX IF NOT EXISTS idx_landmarks_face_id ON landmarks (face_id)", + [], + ) + .change_context(Error)?; + + self.conn + .execute( + "CREATE INDEX IF NOT EXISTS idx_embeddings_face_id ON embeddings (face_id)", + [], + ) + .change_context(Error)?; + + Ok(()) + } + + /// Store image metadata and return the image ID + pub fn store_image(&self, file_path: &str, width: u32, height: u32) -> Result { + let mut stmt = self + .conn + .prepare("INSERT OR REPLACE INTO images (file_path, width, height) VALUES (?1, ?2, ?3)") + .change_context(Error)?; + + stmt.execute(params![file_path, width, height]) + .change_context(Error)?; + + Ok(self.conn.last_insert_rowid()) + } + + /// Store face detection results + pub fn store_face_detections( + &self, + image_id: i64, + detection_output: &FaceDetectionOutput, + ) -> Result> { + let mut face_ids = Vec::new(); + + for (i, bbox) in detection_output.bbox.iter().enumerate() { + let confidence = detection_output.confidence.get(i).copied().unwrap_or(0.0); + + let face_id = self.store_face(image_id, bbox, confidence)?; + face_ids.push(face_id); + + // Store landmarks if available + if let Some(landmarks) = 
detection_output.landmark.get(i) { + self.store_landmarks(face_id, landmarks)?; + } + } + + Ok(face_ids) + } + + /// Store a single face detection + pub fn store_face(&self, image_id: i64, bbox: &Aabb2, confidence: f32) -> Result { + let mut stmt = self + .conn + .prepare( + r#" + INSERT INTO faces (image_id, bbox_x1, bbox_y1, bbox_x2, bbox_y2, confidence) + VALUES (?1, ?2, ?3, ?4, ?5, ?6) + "#, + ) + .change_context(Error)?; + + stmt.execute(params![ + image_id, + bbox.x1() as f32, + bbox.y1() as f32, + bbox.x2() as f32, + bbox.y2() as f32, + confidence + ]) + .change_context(Error)?; + + Ok(self.conn.last_insert_rowid()) + } + + /// Store face landmarks + pub fn store_landmarks(&self, face_id: i64, landmarks: &FaceLandmarks) -> Result { + let mut stmt = self + .conn + .prepare( + r#" + INSERT INTO landmarks + (face_id, left_eye_x, left_eye_y, right_eye_x, right_eye_y, + nose_x, nose_y, left_mouth_x, left_mouth_y, right_mouth_x, right_mouth_y) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11) + "#, + ) + .change_context(Error)?; + + stmt.execute(params![ + face_id, + landmarks.left_eye.x, + landmarks.left_eye.y, + landmarks.right_eye.x, + landmarks.right_eye.y, + landmarks.nose.x, + landmarks.nose.y, + landmarks.left_mouth.x, + landmarks.left_mouth.y, + landmarks.right_mouth.x, + landmarks.right_mouth.y, + ]) + .change_context(Error)?; + + Ok(self.conn.last_insert_rowid()) + } + + /// Store face embeddings + pub fn store_embeddings( + &self, + face_ids: &[i64], + embeddings: &[ndarray::Array2], + model_name: &str, + ) -> Result> { + let mut embedding_ids = Vec::new(); + + for (face_idx, embedding_batch) in embeddings.iter().enumerate() { + for (batch_idx, embedding_row) in embedding_batch.rows().into_iter().enumerate() { + let global_idx = face_idx * embedding_batch.nrows() + batch_idx; + + if global_idx >= face_ids.len() { + break; + } + + let face_id = face_ids[global_idx]; + let embedding_id = + self.store_single_embedding(face_id, embedding_row, 
model_name)?; + embedding_ids.push(embedding_id); + } + } + + Ok(embedding_ids) + } + + /// Store a single embedding + pub fn store_single_embedding( + &self, + face_id: i64, + embedding: ndarray::ArrayView1, + model_name: &str, + ) -> Result { + // Convert f32 slice to bytes + // let embedding_bytes: Vec = embedding.iter().flat_map(|&f| f.to_le_bytes()).collect(); + let embedding_bytes = ndarray_safetensors::SafeArrays::new(); + + let mut stmt = self + .conn + .prepare("INSERT INTO embeddings (face_id, embedding, model_name) VALUES (?1, ?2, ?3)") + .change_context(Error)?; + + stmt.execute(params![face_id, embedding_bytes, model_name]) + .change_context(Error)?; + + Ok(self.conn.last_insert_rowid()) + } + + /// Get image by ID + pub fn get_image(&self, image_id: i64) -> Result> { + let mut stmt = self + .conn + .prepare("SELECT id, file_path, width, height, created_at FROM images WHERE id = ?1") + .change_context(Error)?; + + let result = stmt + .query_row(params![image_id], |row| { + Ok(ImageRecord { + id: row.get(0)?, + file_path: row.get(1)?, + width: row.get(2)?, + height: row.get(3)?, + created_at: row.get(4)?, + }) + }) + .optional() + .change_context(Error)?; + + Ok(result) + } + + /// Get all faces for an image + pub fn get_faces_for_image(&self, image_id: i64) -> Result> { + let mut stmt = self + .conn + .prepare( + r#" + SELECT id, image_id, bbox_x1, bbox_y1, bbox_x2, bbox_y2, confidence, created_at + FROM faces WHERE image_id = ?1 + "#, + ) + .change_context(Error)?; + + let face_iter = stmt + .query_map(params![image_id], |row| { + Ok(FaceRecord { + id: row.get(0)?, + image_id: row.get(1)?, + bbox_x1: row.get(2)?, + bbox_y1: row.get(3)?, + bbox_x2: row.get(4)?, + bbox_y2: row.get(5)?, + confidence: row.get(6)?, + created_at: row.get(7)?, + }) + }) + .change_context(Error)?; + + let mut faces = Vec::new(); + for face in face_iter { + faces.push(face.change_context(Error)?); + } + + Ok(faces) + } + + /// Get landmarks for a face + pub fn 
get_landmarks(&self, face_id: i64) -> Result> { + let mut stmt = self + .conn + .prepare( + r#" + SELECT id, face_id, left_eye_x, left_eye_y, right_eye_x, right_eye_y, + nose_x, nose_y, left_mouth_x, left_mouth_y, right_mouth_x, right_mouth_y + FROM landmarks WHERE face_id = ?1 + "#, + ) + .change_context(Error)?; + + let result = stmt + .query_row(params![face_id], |row| { + Ok(LandmarkRecord { + id: row.get(0)?, + face_id: row.get(1)?, + left_eye_x: row.get(2)?, + left_eye_y: row.get(3)?, + right_eye_x: row.get(4)?, + right_eye_y: row.get(5)?, + nose_x: row.get(6)?, + nose_y: row.get(7)?, + left_mouth_x: row.get(8)?, + left_mouth_y: row.get(9)?, + right_mouth_x: row.get(10)?, + right_mouth_y: row.get(11)?, + }) + }) + .optional() + .change_context(Error)?; + + Ok(result) + } + + /// Get embeddings for a face + pub fn get_embeddings(&self, face_id: i64) -> Result> { + let mut stmt = self + .conn + .prepare( + "SELECT id, face_id, embedding, model_name, created_at FROM embeddings WHERE face_id = ?1", + ) + .change_context(Error)?; + + let embedding_iter = stmt + .query_map(params![face_id], |row| { + let embedding_bytes: Vec = row.get(2)?; + let embedding: Vec = embedding_bytes + .chunks_exact(4) + .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) + .collect(); + + Ok(EmbeddingRecord { + id: row.get(0)?, + face_id: row.get(1)?, + embedding, + model_name: row.get(3)?, + created_at: row.get(4)?, + }) + }) + .change_context(Error)?; + + let mut embeddings = Vec::new(); + for embedding in embedding_iter { + embeddings.push(embedding.change_context(Error)?); + } + + Ok(embeddings) + } + + /// Find similar faces by embedding (using cosine similarity) + pub fn find_similar_faces( + &self, + query_embedding: &[f32], + threshold: f32, + limit: usize, + ) -> Result> { + let mut stmt = self + .conn + .prepare("SELECT face_id, embedding FROM embeddings") + .change_context(Error)?; + + let embedding_iter = stmt + .query_map([], |row| { + let face_id: 
i64 = row.get(0)?; + let embedding_bytes: Vec = row.get(1)?; + let embedding: Vec = embedding_bytes + .chunks_exact(4) + .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) + .collect(); + Ok((face_id, embedding)) + }) + .change_context(Error)?; + + let mut similarities = Vec::new(); + for result in embedding_iter { + let (face_id, embedding) = result.change_context(Error)?; + let similarity = cosine_similarity(query_embedding, &embedding); + if similarity >= threshold { + similarities.push((face_id, similarity)); + } + } + + // Sort by similarity (descending) and limit results + similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + similarities.truncate(limit); + + Ok(similarities) + } + + /// Get database statistics + pub fn get_stats(&self) -> Result<(usize, usize, usize, usize)> { + let images: usize = self + .conn + .query_row("SELECT COUNT(*) FROM images", [], |row| row.get(0)) + .change_context(Error)?; + + let faces: usize = self + .conn + .query_row("SELECT COUNT(*) FROM faces", [], |row| row.get(0)) + .change_context(Error)?; + + let landmarks: usize = self + .conn + .query_row("SELECT COUNT(*) FROM landmarks", [], |row| row.get(0)) + .change_context(Error)?; + + let embeddings: usize = self + .conn + .query_row("SELECT COUNT(*) FROM embeddings", [], |row| row.get(0)) + .change_context(Error)?; + + Ok((images, faces, landmarks, embeddings)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_database_creation() -> Result<()> { + let db = FaceDatabase::in_memory()?; + let (images, faces, landmarks, embeddings) = db.get_stats()?; + assert_eq!(images, 0); + assert_eq!(faces, 0); + assert_eq!(landmarks, 0); + assert_eq!(embeddings, 0); + Ok(()) + } + + #[test] + fn test_store_and_retrieve_image() -> Result<()> { + let db = FaceDatabase::in_memory()?; + let image_id = db.store_image("/path/to/image.jpg", 800, 600)?; + + let image = db.get_image(image_id)?.unwrap(); + 
assert_eq!(image.file_path, "/path/to/image.jpg"); + assert_eq!(image.width, 800); + assert_eq!(image.height, 600); + + Ok(()) + } +} diff --git a/src/lib.rs b/src/lib.rs index 14dc5c0..3d8e723 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +pub mod database; pub mod errors; pub mod facedet; pub mod faceembed; diff --git a/src/main.rs b/src/main.rs index 16ae9cd..5d5e64e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,10 @@ mod cli; mod errors; use bounding_box::roi::MultiRoi; -use detector::{facedet, facedet::FaceDetectionConfig, faceembed}; +use detector::{database::FaceDatabase, facedet, facedet::FaceDetectionConfig, faceembed}; use errors::*; use fast_image_resize::ResizeOptions; + use ndarray::*; use ndarray_image::*; use ndarray_resize::NdFir; @@ -77,6 +78,15 @@ pub fn main() -> Result<()> { cli::SubCommand::List(list) => { println!("List: {:?}", list); } + cli::SubCommand::Query(query) => { + run_query(query)?; + } + cli::SubCommand::Similar(similar) => { + run_similar(similar)?; + } + cli::SubCommand::Stats(stats) => { + run_stats(stats)?; + } cli::SubCommand::Completions { shell } => { cli::Cli::completions(shell); } @@ -89,10 +99,22 @@ where D: facedet::FaceDetector, E: faceembed::FaceEmbedder, { + // Initialize database if requested + let db = if detect.save_to_db { + let db_path = detect + .database + .as_ref() + .map(|p| p.as_path()) + .unwrap_or_else(|| std::path::Path::new("face_detections.db")); + Some(FaceDatabase::new(db_path).change_context(Error)?) 
+ } else { + None + }; let image = image::open(&detect.image) .change_context(Error) .attach_printable(detect.image.to_string_lossy().to_string())?; let image = image.into_rgb8(); + let (image_width, image_height) = image.dimensions(); let mut array = image .into_ndarray() .change_context(errors::Error) @@ -106,6 +128,26 @@ where ) .change_context(errors::Error) .attach_printable("Failed to detect faces")?; + + // Store image and face detections in database if requested + let (image_id, face_ids) = if let Some(ref database) = db { + let image_path = detect.image.to_string_lossy(); + let img_id = database + .store_image(&image_path, image_width, image_height) + .change_context(Error)?; + let face_ids = database + .store_face_detections(img_id, &output) + .change_context(Error)?; + tracing::info!( + "Stored image {} with {} faces in database", + img_id, + face_ids.len() + ); + (Some(img_id), Some(face_ids)) + } else { + (None, None) + }; + for bbox in &output.bbox { tracing::info!("Detected face: {:?}", bbox); use bounding_box::draw::*; @@ -159,6 +201,25 @@ where }) .collect::>>>()?; + // Store embeddings in database if requested + if let (Some(database), Some(face_ids)) = (&db, &face_ids) { + let embedding_ids = database + .store_embeddings(face_ids, &embeddings, &detect.model_name) + .change_context(Error)?; + tracing::info!("Stored {} embeddings in database", embedding_ids.len()); + + // Print database statistics + let (num_images, num_faces, num_landmarks, num_embeddings) = + database.get_stats().change_context(Error)?; + tracing::info!( + "Database stats - Images: {}, Faces: {}, Landmarks: {}, Embeddings: {}", + num_images, + num_faces, + num_landmarks, + num_embeddings + ); + } + let v = array.view(); if let Some(output) = detect.output { let image: image::RgbImage = v @@ -173,3 +234,138 @@ where Ok(()) } + +fn run_query(query: cli::Query) -> Result<()> { + let db = FaceDatabase::new(&query.database).change_context(Error)?; + + if let Some(image_id) = 
query.image_id { + if let Some(image) = db.get_image(image_id).change_context(Error)? { + println!("Image: {}", image.file_path); + println!("Dimensions: {}x{}", image.width, image.height); + println!("Created: {}", image.created_at); + + let faces = db.get_faces_for_image(image_id).change_context(Error)?; + println!("Faces found: {}", faces.len()); + + for face in faces { + println!( + " Face ID {}: bbox({:.1}, {:.1}, {:.1}, {:.1}), confidence: {:.3}", + face.id, + face.bbox_x1, + face.bbox_y1, + face.bbox_x2, + face.bbox_y2, + face.confidence + ); + + if query.show_landmarks { + if let Some(landmarks) = db.get_landmarks(face.id).change_context(Error)? { + println!( + " Landmarks: left_eye({:.1}, {:.1}), right_eye({:.1}, {:.1}), nose({:.1}, {:.1})", + landmarks.left_eye_x, + landmarks.left_eye_y, + landmarks.right_eye_x, + landmarks.right_eye_y, + landmarks.nose_x, + landmarks.nose_y + ); + } + } + + if query.show_embeddings { + let embeddings = db.get_embeddings(face.id).change_context(Error)?; + for embedding in embeddings { + println!( + " Embedding ({}): {} dims, model: {}", + embedding.id, + embedding.embedding.len(), + embedding.model_name + ); + } + } + } + } else { + println!("Image with ID {} not found", image_id); + } + } + + if let Some(face_id) = query.face_id { + if let Some(landmarks) = db.get_landmarks(face_id).change_context(Error)? 
{ + println!( + "Landmarks for face {}: left_eye({:.1}, {:.1}), right_eye({:.1}, {:.1}), nose({:.1}, {:.1})", + face_id, + landmarks.left_eye_x, + landmarks.left_eye_y, + landmarks.right_eye_x, + landmarks.right_eye_y, + landmarks.nose_x, + landmarks.nose_y + ); + } else { + println!("No landmarks found for face {}", face_id); + } + + let embeddings = db.get_embeddings(face_id).change_context(Error)?; + println!( + "Embeddings for face {}: {} found", + face_id, + embeddings.len() + ); + for embedding in embeddings { + println!( + " Embedding {}: {} dims, model: {}, created: {}", + embedding.id, + embedding.embedding.len(), + embedding.model_name, + embedding.created_at + ); + if query.show_embeddings { + println!( + " Values: {:?}", + &embedding.embedding[..embedding.embedding.len().min(10)] + ); + } + } + } + + Ok(()) +} + +fn run_similar(similar: cli::Similar) -> Result<()> { + let db = FaceDatabase::new(&similar.database).change_context(Error)?; + + let embeddings = db.get_embeddings(similar.face_id).change_context(Error)?; + if embeddings.is_empty() { + println!("No embeddings found for face {}", similar.face_id); + return Ok(()); + } + + let query_embedding = &embeddings[0].embedding; + let similar_faces = db + .find_similar_faces(query_embedding, similar.threshold, similar.limit) + .change_context(Error)?; + + println!( + "Found {} similar faces (threshold: {:.3}):", + similar_faces.len(), + similar.threshold + ); + for (face_id, similarity) in similar_faces { + println!(" Face {}: similarity {:.3}", face_id, similarity); + } + + Ok(()) +} + +fn run_stats(stats: cli::Stats) -> Result<()> { + let db = FaceDatabase::new(&stats.database).change_context(Error)?; + let (images, faces, landmarks, embeddings) = db.get_stats().change_context(Error)?; + + println!("Database Statistics:"); + println!(" Images: {}", images); + println!(" Faces: {}", faces); + println!(" Landmarks: {}", landmarks); + println!(" Embeddings: {}", embeddings); + + Ok(()) +}