From a340552257943091cf07e035401ac35b097a1dbd Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Sat, 13 Sep 2025 17:45:55 +0530 Subject: [PATCH] feat(cli): add clustering command with K-means support --- Cargo.lock | 111 ++++++++++++++++++++++++++++++ Cargo.toml | 2 + src/bin/detector-cli/cli.rs | 18 +++++ src/bin/detector-cli/main.rs | 127 ++++++++++++++++++++++++++++++++++- src/database.rs | 87 +++++++++++++++++++++--- 5 files changed, 336 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 44191b4..ee5f788 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1272,6 +1272,7 @@ dependencies = [ "imageproc", "itertools 0.14.0", "linfa", + "linfa-clustering", "mnn", "mnn-bridge", "mnn-sync", @@ -2809,6 +2810,16 @@ dependencies = [ "mutate_once", ] +[[package]] +name = "kdtree" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f0a0e9f770b65bac9aad00f97a67ab5c5319effed07f6da385da3c2115e47ba" +dependencies = [ + "num-traits", + "thiserror 1.0.69", +] + [[package]] name = "khronos-egl" version = "6.0.0" @@ -2930,6 +2941,50 @@ dependencies = [ "thiserror 2.0.15", ] +[[package]] +name = "linfa-clustering" +version = "0.7.1" +source = "git+https://github.com/relf/linfa?branch=upgrade-ndarray-0.16#c1fbee7c54e806de3f5fb2c5240ce163d000f1ba" +dependencies = [ + "linfa", + "linfa-linalg", + "linfa-nn", + "ndarray", + "ndarray-rand", + "ndarray-stats", + "noisy_float", + "num-traits", + "rand_xoshiro", + "space", + "thiserror 2.0.15", +] + +[[package]] +name = "linfa-linalg" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02a834c0ec063937688a0d13573aa515ab8c425bd8de3154b908dd3b9c197dc4" +dependencies = [ + "ndarray", + "num-traits", + "thiserror 1.0.69", +] + +[[package]] +name = "linfa-nn" +version = "0.7.2" +source = "git+https://github.com/relf/linfa?branch=upgrade-ndarray-0.16#c1fbee7c54e806de3f5fb2c5240ce163d000f1ba" +dependencies = [ + "kdtree", + "linfa", + "ndarray", + "ndarray-stats", + "noisy_float", + "num-traits", + "order-stat", + "thiserror 2.0.15", +] + [[package]] name = "linux-raw-sys" version = "0.4.15" @@ -3280,6 +3335,17 @@ dependencies = [ "zip", ] +[[package]] +name = "ndarray-rand" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f093b3db6fd194718dcdeea6bd8c829417deae904e3fcc7732dabcd4416d25d8" +dependencies = [ + "ndarray", + "rand 0.8.5", + "rand_distr", +] + [[package]] name = "ndarray-resize" version = "0.1.0" @@ -3303,6 +3369,21 @@ dependencies = [ "thiserror 2.0.15", ] +[[package]] +name = "ndarray-stats" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17ebbe97acce52d06aebed4cd4a87c0941f4b2519b59b82b4feb5bd0ce003dfd" +dependencies = [ + "indexmap 2.10.0", + "itertools 0.13.0", + "ndarray", + "noisy_float", + "num-integer", + "num-traits", + "rand 0.8.5", +] + [[package]] name = "ndcv-bridge" version = "0.1.0" @@ -3395,6 +3476,15 @@ dependencies = [ "memoffset", ] +[[package]] +name = "noisy_float" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978fe6e6ebc0bf53de533cd456ca2d9de13de13856eda1518a285d7705a213af" +dependencies = [ + "num-traits", +] + [[package]] name = "nom" version = "7.1.3" @@ -3889,6 +3979,12 @@ dependencies = [ "libredox", ] +[[package]] +name = "order-stat" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efa535d5117d3661134dbf1719b6f0ffe06f2375843b13935db186cd094105eb" + [[package]] name = "ordered-float" version = "5.0.0" @@ -4437,6 +4533,15 @@ dependencies = [ "rand 0.8.5", ] +[[package]] +name = "rand_xoshiro" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa" +dependencies = [ + "rand_core 0.6.4", +] + [[package]] name = "range-alloc" version = "0.1.4" @@ -5108,6 +5213,12 @@ dependencies = [ "x11rb", ] +[[package]] +name = "space" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e990cc6cb89a82d70fe722cd7811dbce48a72bbfaebd623e58f142b6db28428f" + [[package]] name = "spin" version = "0.9.8" diff --git a/Cargo.toml b/Cargo.toml index 4c7582b..131dd16 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ edition = "2024" [patch.crates-io] linfa = { git = "https://github.com/relf/linfa", branch = "upgrade-ndarray-0.16" } +linfa-clustering = { git = "https://github.com/relf/linfa", branch = "upgrade-ndarray-0.16" } [workspace.dependencies] divan = { version = "0.1.21" } @@ -90,6 +91,7 @@ rfd = "0.15" futures = "0.3" imageproc = "0.25" linfa = "0.7.1" +linfa-clustering = "0.7.1" [profile.release] debug = true diff --git a/src/bin/detector-cli/cli.rs b/src/bin/detector-cli/cli.rs index 78bbd65..30bf7f3 100644 --- a/src/bin/detector-cli/cli.rs +++ b/src/bin/detector-cli/cli.rs @@ -21,6 +21,8 @@ pub enum SubCommand { Stats(Stats), #[clap(name = "compare")] Compare(Compare), + #[clap(name = "cluster")] + Cluster(Cluster), #[clap(name = "gui")] Gui, #[clap(name = "completions")] @@ -187,6 +189,22 @@ pub struct Compare { pub image2: PathBuf, } +#[derive(Debug, clap::Args)] +pub struct Cluster { + #[clap(short = 'd', long, default_value = "face_detections.db")] + pub database: PathBuf, + #[clap(short, long, default_value_t = 5)] + pub clusters: usize, + #[clap(short, long, default_value_t = 100)] + pub max_iterations: u64, + #[clap(short, long, default_value_t = 1e-4)] + pub tolerance: f64, + #[clap(long, default_value = "facenet")] + pub model_name: String, + #[clap(short, long)] + pub output: Option, +} + impl Cli { pub fn completions(shell: clap_complete::Shell) { let mut command = ::command(); diff --git a/src/bin/detector-cli/main.rs b/src/bin/detector-cli/main.rs index 4cbd9df..843079a 100644 --- a/src/bin/detector-cli/main.rs +++ b/src/bin/detector-cli/main.rs @@ -2,9 +2,11 @@ mod cli; use bounding_box::roi::MultiRoi; use detector::*; use detector::{database::FaceDatabase, facedet, facedet::FaceDetectionConfig, faceembed}; -use errors::*; + use fast_image_resize::ResizeOptions; +use linfa::prelude::*; +use linfa_clustering::KMeans; use ndarray::*; use ndarray_image::*; use ndarray_resize::NdFir; @@ -209,6 +211,9 @@ pub fn main() -> Result<()> { std::process::exit(1); } } + cli::SubCommand::Cluster(cluster) => { + run_clustering(cluster)?; + } cli::SubCommand::Completions { shell } => { cli::Cli::completions(shell); } @@ -936,3 +941,123 @@ fn run_stats(stats: cli::Stats) -> Result<()> { Ok(()) } + +fn run_clustering(cluster: cli::Cluster) -> Result<()> { + let db = FaceDatabase::new(&cluster.database).change_context(Error)?; + + // Get all embeddings for the specified model + let embeddings = db + .get_all_embeddings(Some(&cluster.model_name)) + .change_context(Error)?; + + if embeddings.is_empty() { + println!("No embeddings found for model '{}'", cluster.model_name); + return Ok(()); + } + + println!( + "Found {} embeddings for model '{}'", + embeddings.len(), + cluster.model_name + ); + + if embeddings.len() < cluster.clusters { + println!( + "Warning: Number of embeddings ({}) is less than requested clusters ({})", + embeddings.len(), + cluster.clusters + ); + return Ok(()); + } + + // Convert embeddings to a 2D array for clustering + let embedding_dim = embeddings[0].embedding.len(); + let mut data = Array2::::zeros((embeddings.len(), embedding_dim)); + + for (i, embedding_record) in embeddings.iter().enumerate() { + for (j, &val) in embedding_record.embedding.iter().enumerate() { + data[[i, j]] = val as f64; + } + } + + println!( + "Running K-means clustering with {} clusters...", + cluster.clusters + ); + + // Create dataset + let dataset = linfa::Dataset::new(data, Array1::::zeros(embeddings.len())); + + // Configure and run K-means + let model = KMeans::params(cluster.clusters) + .max_n_iterations(cluster.max_iterations) + .tolerance(cluster.tolerance) + .fit(&dataset) + .change_context(Error) + .attach_printable("Failed to run K-means clustering")?; + + // Get cluster assignments + let predictions = model.predict(&dataset); + + // Group results by cluster + let mut clusters: std::collections::HashMap> = + std::collections::HashMap::new(); + + for (i, &cluster_id) in predictions.iter().enumerate() { + let face_id = embeddings[i].face_id; + + // Get image path for this face + let image_info = db.get_image_for_face(face_id).change_context(Error)?; + let image_path = image_info + .map(|info| info.file_path) + .unwrap_or_else(|| "Unknown".to_string()); + + clusters + .entry(cluster_id) + .or_insert_with(Vec::new) + .push((face_id, image_path)); + } + + // Print results + println!("\nClustering Results:"); + for cluster_id in 0..cluster.clusters { + if let Some(faces) = clusters.get(&cluster_id) { + println!("\nCluster {}: {} faces", cluster_id, faces.len()); + for (face_id, image_path) in faces { + println!(" Face ID: {}, Image: {}", face_id, image_path); + } + } else { + println!("\nCluster {}: 0 faces", cluster_id); + } + } + + // Optionally save results to file + if let Some(output_path) = &cluster.output { + use std::io::Write; + let mut file = std::fs::File::create(output_path) + .change_context(Error) + .attach_printable("Failed to create output file")?; + + writeln!(file, "K-means Clustering Results").change_context(Error)?; + writeln!(file, "Model: {}", cluster.model_name).change_context(Error)?; + writeln!(file, "Total embeddings: {}", embeddings.len()).change_context(Error)?; + writeln!(file, "Number of clusters: {}", cluster.clusters).change_context(Error)?; + writeln!(file, "").change_context(Error)?; + + for cluster_id in 0..cluster.clusters { + if let Some(faces) = clusters.get(&cluster_id) { + writeln!(file, "Cluster {}: {} faces", cluster_id, faces.len()) + .change_context(Error)?; + for (face_id, image_path) in faces { + writeln!(file, " Face ID: {}, Image: {}", face_id, image_path) + .change_context(Error)?; + } + writeln!(file, "").change_context(Error)?; + } + } + + println!("\nResults saved to: {:?}", output_path); + } + + Ok(()) +} diff --git a/src/database.rs b/src/database.rs index 048e54d..780fb32 100644 --- a/src/database.rs +++ b/src/database.rs @@ -65,14 +65,15 @@ impl FaceDatabase { /// Create a new database connection and initialize tables pub fn new>(db_path: P) -> Result { let conn = Connection::open(db_path).change_context(Error)?; - unsafe { - let _guard = rusqlite::LoadExtensionGuard::new(&conn).change_context(Error)?; - conn.load_extension( - "/Users/fs0c131y/.cache/cargo/target/release/libsqlite3_safetensor_cosine.dylib", - None::<&str>, - ) - .change_context(Error)?; - } + // Temporarily disable extension loading for clustering demo + // unsafe { + // let _guard = rusqlite::LoadExtensionGuard::new(&conn).change_context(Error)?; + // conn.load_extension( + // "/Users/fs0c131y/.cache/cargo/target/release/libsqlite3_safetensor_cosine.dylib", + // None::<&str>, + // ) + // .change_context(Error)?; + // } let db = Self { conn }; db.create_tables()?; Ok(db) @@ -594,6 +595,76 @@ impl FaceDatabase { println!("{:?}", result); } } + + /// Get all embeddings for a specific model + pub fn get_all_embeddings(&self, model_name: Option<&str>) -> Result> { + let mut embeddings = Vec::new(); + + if let Some(model) = model_name { + let mut stmt = self.conn.prepare( + "SELECT id, face_id, embedding, model_name, created_at FROM embeddings WHERE model_name = ?1" + ).change_context(Error)?; + + let embedding_iter = stmt + .query_map(params![model], |row| { + let embedding_bytes: Vec = row.get(2)?; + let embedding: ndarray::Array1 = { + let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes) + .change_context(Error) + .unwrap(); + sf.tensor::("embedding") + .unwrap() + .to_owned() + }; + + Ok(EmbeddingRecord { + id: row.get(0)?, + face_id: row.get(1)?, + embedding, + model_name: row.get(3)?, + created_at: row.get(4)?, + }) + }) + .change_context(Error)?; + + for embedding in embedding_iter { + embeddings.push(embedding.change_context(Error)?); + } + } else { + let mut stmt = self + .conn + .prepare("SELECT id, face_id, embedding, model_name, created_at FROM embeddings") + .change_context(Error)?; + + let embedding_iter = stmt + .query_map([], |row| { + let embedding_bytes: Vec = row.get(2)?; + let embedding: ndarray::Array1 = { + let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes) + .change_context(Error) + .unwrap(); + sf.tensor::("embedding") + .unwrap() + .to_owned() + }; + + Ok(EmbeddingRecord { + id: row.get(0)?, + face_id: row.get(1)?, + embedding, + model_name: row.get(3)?, + created_at: row.get(4)?, + }) + }) + .change_context(Error)?; + + for embedding in embedding_iter { + embeddings.push(embedding.change_context(Error)?); + } + } + + Ok(embeddings) + } } fn add_sqlite_cosine_similarity(db: &Connection) -> Result<()> {