feat(cli): add clustering command with K-means support
Some checks failed
build / checks-matrix (push) Successful in 19m25s
build / codecov (push) Failing after 19m26s
docs / docs (push) Failing after 28m52s
build / checks-build (push) Has been cancelled

uttarayan21
2025-09-13 17:45:55 +05:30
parent aaf34ef74e
commit a340552257
5 changed files with 336 additions and 9 deletions

Cargo.lock (generated)

@@ -1272,6 +1272,7 @@ dependencies = [
"imageproc",
"itertools 0.14.0",
"linfa",
"linfa-clustering",
"mnn",
"mnn-bridge",
"mnn-sync",
@@ -2809,6 +2810,16 @@ dependencies = [
"mutate_once",
]
[[package]]
name = "kdtree"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f0a0e9f770b65bac9aad00f97a67ab5c5319effed07f6da385da3c2115e47ba"
dependencies = [
"num-traits",
"thiserror 1.0.69",
]
[[package]]
name = "khronos-egl"
version = "6.0.0"
@@ -2930,6 +2941,50 @@ dependencies = [
"thiserror 2.0.15",
]
[[package]]
name = "linfa-clustering"
version = "0.7.1"
source = "git+https://github.com/relf/linfa?branch=upgrade-ndarray-0.16#c1fbee7c54e806de3f5fb2c5240ce163d000f1ba"
dependencies = [
"linfa",
"linfa-linalg",
"linfa-nn",
"ndarray",
"ndarray-rand",
"ndarray-stats",
"noisy_float",
"num-traits",
"rand_xoshiro",
"space",
"thiserror 2.0.15",
]
[[package]]
name = "linfa-linalg"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02a834c0ec063937688a0d13573aa515ab8c425bd8de3154b908dd3b9c197dc4"
dependencies = [
"ndarray",
"num-traits",
"thiserror 1.0.69",
]
[[package]]
name = "linfa-nn"
version = "0.7.2"
source = "git+https://github.com/relf/linfa?branch=upgrade-ndarray-0.16#c1fbee7c54e806de3f5fb2c5240ce163d000f1ba"
dependencies = [
"kdtree",
"linfa",
"ndarray",
"ndarray-stats",
"noisy_float",
"num-traits",
"order-stat",
"thiserror 2.0.15",
]
[[package]]
name = "linux-raw-sys"
version = "0.4.15"
@@ -3280,6 +3335,17 @@ dependencies = [
"zip",
]
[[package]]
name = "ndarray-rand"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f093b3db6fd194718dcdeea6bd8c829417deae904e3fcc7732dabcd4416d25d8"
dependencies = [
"ndarray",
"rand 0.8.5",
"rand_distr",
]
[[package]]
name = "ndarray-resize"
version = "0.1.0"
@@ -3303,6 +3369,21 @@ dependencies = [
"thiserror 2.0.15",
]
[[package]]
name = "ndarray-stats"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17ebbe97acce52d06aebed4cd4a87c0941f4b2519b59b82b4feb5bd0ce003dfd"
dependencies = [
"indexmap 2.10.0",
"itertools 0.13.0",
"ndarray",
"noisy_float",
"num-integer",
"num-traits",
"rand 0.8.5",
]
[[package]]
name = "ndcv-bridge"
version = "0.1.0"
@@ -3395,6 +3476,15 @@ dependencies = [
"memoffset",
]
[[package]]
name = "noisy_float"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "978fe6e6ebc0bf53de533cd456ca2d9de13de13856eda1518a285d7705a213af"
dependencies = [
"num-traits",
]
[[package]]
name = "nom"
version = "7.1.3"
@@ -3889,6 +3979,12 @@ dependencies = [
"libredox",
]
[[package]]
name = "order-stat"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "efa535d5117d3661134dbf1719b6f0ffe06f2375843b13935db186cd094105eb"
[[package]]
name = "ordered-float"
version = "5.0.0"
@@ -4437,6 +4533,15 @@ dependencies = [
"rand 0.8.5",
]
[[package]]
name = "rand_xoshiro"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa"
dependencies = [
"rand_core 0.6.4",
]
[[package]]
name = "range-alloc"
version = "0.1.4"
@@ -5108,6 +5213,12 @@ dependencies = [
"x11rb",
]
[[package]]
name = "space"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e990cc6cb89a82d70fe722cd7811dbce48a72bbfaebd623e58f142b6db28428f"
[[package]]
name = "spin"
version = "0.9.8"


@@ -16,6 +16,7 @@ edition = "2024"
[patch.crates-io]
linfa = { git = "https://github.com/relf/linfa", branch = "upgrade-ndarray-0.16" }
linfa-clustering = { git = "https://github.com/relf/linfa", branch = "upgrade-ndarray-0.16" }
[workspace.dependencies]
divan = { version = "0.1.21" }
@@ -90,6 +91,7 @@ rfd = "0.15"
futures = "0.3"
imageproc = "0.25"
linfa = "0.7.1"
linfa-clustering = "0.7.1"
[profile.release]
debug = true


@@ -21,6 +21,8 @@ pub enum SubCommand {
Stats(Stats),
#[clap(name = "compare")]
Compare(Compare),
#[clap(name = "cluster")]
Cluster(Cluster),
#[clap(name = "gui")]
Gui,
#[clap(name = "completions")]
@@ -187,6 +189,22 @@ pub struct Compare {
pub image2: PathBuf,
}
#[derive(Debug, clap::Args)]
pub struct Cluster {
#[clap(short = 'd', long, default_value = "face_detections.db")]
pub database: PathBuf,
#[clap(short, long, default_value_t = 5)]
pub clusters: usize,
#[clap(short, long, default_value_t = 100)]
pub max_iterations: u64,
#[clap(short, long, default_value_t = 1e-4)]
pub tolerance: f64,
#[clap(long, default_value = "facenet")]
pub model_name: String,
#[clap(short, long)]
pub output: Option<PathBuf>,
}
impl Cli {
pub fn completions(shell: clap_complete::Shell) {
let mut command = <Cli as clap::CommandFactory>::command();

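A minimal sketch (not part of this commit) of how the new cluster flags are expected to parse, assuming Cli derives clap::Parser (consistent with the clap::CommandFactory call above); the binary name "facedet" is a placeholder:

    // Illustrative sketch only: Cli comes from this module; "facedet" is a placeholder binary name.
    use clap::Parser;

    fn demo_parse_cluster_args() {
        let _cli = Cli::try_parse_from([
            "facedet", "cluster",
            "--clusters", "8",
            "--max-iterations", "200",
            "--tolerance", "1e-5",
            "--model-name", "facenet",
            "--output", "clusters.txt",
        ])
        .expect("flags should match the Cluster arguments defined above");
    }

Flags left out fall back to the defaults declared on Cluster: database "face_detections.db", 5 clusters, 100 iterations, tolerance 1e-4, and model name "facenet".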

@@ -2,9 +2,11 @@ mod cli;
use bounding_box::roi::MultiRoi;
use detector::*;
use detector::{database::FaceDatabase, facedet, facedet::FaceDetectionConfig, faceembed};
use errors::*;
use fast_image_resize::ResizeOptions;
use linfa::prelude::*;
use linfa_clustering::KMeans;
use ndarray::*;
use ndarray_image::*;
use ndarray_resize::NdFir;
@@ -209,6 +211,9 @@ pub fn main() -> Result<()> {
std::process::exit(1);
}
}
cli::SubCommand::Cluster(cluster) => {
run_clustering(cluster)?;
}
cli::SubCommand::Completions { shell } => {
cli::Cli::completions(shell);
}
@@ -936,3 +941,123 @@ fn run_stats(stats: cli::Stats) -> Result<()> {
Ok(())
}
fn run_clustering(cluster: cli::Cluster) -> Result<()> {
let db = FaceDatabase::new(&cluster.database).change_context(Error)?;
// Get all embeddings for the specified model
let embeddings = db
.get_all_embeddings(Some(&cluster.model_name))
.change_context(Error)?;
if embeddings.is_empty() {
println!("No embeddings found for model '{}'", cluster.model_name);
return Ok(());
}
println!(
"Found {} embeddings for model '{}'",
embeddings.len(),
cluster.model_name
);
if embeddings.len() < cluster.clusters {
println!(
"Warning: Number of embeddings ({}) is less than requested clusters ({})",
embeddings.len(),
cluster.clusters
);
return Ok(());
}
// Convert embeddings to a 2D array for clustering
let embedding_dim = embeddings[0].embedding.len();
let mut data = Array2::<f64>::zeros((embeddings.len(), embedding_dim));
for (i, embedding_record) in embeddings.iter().enumerate() {
for (j, &val) in embedding_record.embedding.iter().enumerate() {
data[[i, j]] = val as f64;
}
}
println!(
"Running K-means clustering with {} clusters...",
cluster.clusters
);
// Create dataset
let dataset = linfa::Dataset::new(data, Array1::<usize>::zeros(embeddings.len()));
// Configure and run K-means
let model = KMeans::params(cluster.clusters)
.max_n_iterations(cluster.max_iterations)
.tolerance(cluster.tolerance)
.fit(&dataset)
.change_context(Error)
.attach_printable("Failed to run K-means clustering")?;
// Get cluster assignments
let predictions = model.predict(&dataset);
// Group results by cluster
let mut clusters: std::collections::HashMap<usize, Vec<(i64, String)>> =
std::collections::HashMap::new();
for (i, &cluster_id) in predictions.iter().enumerate() {
let face_id = embeddings[i].face_id;
// Get image path for this face
let image_info = db.get_image_for_face(face_id).change_context(Error)?;
let image_path = image_info
.map(|info| info.file_path)
.unwrap_or_else(|| "Unknown".to_string());
clusters
.entry(cluster_id)
.or_insert_with(Vec::new)
.push((face_id, image_path));
}
// Print results
println!("\nClustering Results:");
for cluster_id in 0..cluster.clusters {
if let Some(faces) = clusters.get(&cluster_id) {
println!("\nCluster {}: {} faces", cluster_id, faces.len());
for (face_id, image_path) in faces {
println!(" Face ID: {}, Image: {}", face_id, image_path);
}
} else {
println!("\nCluster {}: 0 faces", cluster_id);
}
}
// Optionally save results to file
if let Some(output_path) = &cluster.output {
use std::io::Write;
let mut file = std::fs::File::create(output_path)
.change_context(Error)
.attach_printable("Failed to create output file")?;
writeln!(file, "K-means Clustering Results").change_context(Error)?;
writeln!(file, "Model: {}", cluster.model_name).change_context(Error)?;
writeln!(file, "Total embeddings: {}", embeddings.len()).change_context(Error)?;
writeln!(file, "Number of clusters: {}", cluster.clusters).change_context(Error)?;
writeln!(file, "").change_context(Error)?;
for cluster_id in 0..cluster.clusters {
if let Some(faces) = clusters.get(&cluster_id) {
writeln!(file, "Cluster {}: {} faces", cluster_id, faces.len())
.change_context(Error)?;
for (face_id, image_path) in faces {
writeln!(file, " Face ID: {}, Image: {}", face_id, image_path)
.change_context(Error)?;
}
writeln!(file, "").change_context(Error)?;
}
}
println!("\nResults saved to: {:?}", output_path);
}
Ok(())
}
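run_clustering feeds the raw embedding matrix to Euclidean K-means. Face embeddings are usually compared by cosine similarity, so a possible refinement (not part of this commit) is to L2-normalize each row before building the dataset; on unit vectors, squared Euclidean distance equals 2 - 2*cos, so K-means then groups by direction rather than magnitude. A minimal ndarray-only sketch:

    // Hypothetical preprocessing helper, not part of this commit: L2-normalize each embedding
    // row so that Euclidean K-means effectively clusters by cosine similarity.
    use ndarray::{Array2, Axis};

    fn normalize_rows(mut data: Array2<f64>) -> Array2<f64> {
        for mut row in data.axis_iter_mut(Axis(0)) {
            let norm = row.dot(&row).sqrt();
            if norm > 0.0 {
                row.mapv_inplace(|v| v / norm);
            }
        }
        data
    }

If adopted, the call would slot in just before linfa::Dataset::new(data, ...) in run_clustering.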


@@ -65,14 +65,15 @@ impl FaceDatabase {
/// Create a new database connection and initialize tables
pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Self> {
let conn = Connection::open(db_path).change_context(Error)?;
- unsafe {
- let _guard = rusqlite::LoadExtensionGuard::new(&conn).change_context(Error)?;
- conn.load_extension(
- "/Users/fs0c131y/.cache/cargo/target/release/libsqlite3_safetensor_cosine.dylib",
- None::<&str>,
- )
- .change_context(Error)?;
- }
+ // Temporarily disable extension loading for clustering demo
+ // unsafe {
+ // let _guard = rusqlite::LoadExtensionGuard::new(&conn).change_context(Error)?;
+ // conn.load_extension(
+ // "/Users/fs0c131y/.cache/cargo/target/release/libsqlite3_safetensor_cosine.dylib",
+ // None::<&str>,
+ // )
+ // .change_context(Error)?;
+ // }
let db = Self { conn };
db.create_tables()?;
Ok(db)
@@ -594,6 +595,76 @@ impl FaceDatabase {
println!("{:?}", result);
}
}
/// Get all embeddings for a specific model
pub fn get_all_embeddings(&self, model_name: Option<&str>) -> Result<Vec<EmbeddingRecord>> {
let mut embeddings = Vec::new();
if let Some(model) = model_name {
let mut stmt = self.conn.prepare(
"SELECT id, face_id, embedding, model_name, created_at FROM embeddings WHERE model_name = ?1"
).change_context(Error)?;
let embedding_iter = stmt
.query_map(params![model], |row| {
let embedding_bytes: Vec<u8> = row.get(2)?;
let embedding: ndarray::Array1<f32> = {
let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes)
.change_context(Error)
.unwrap();
sf.tensor::<f32, ndarray::Ix1>("embedding")
.unwrap()
.to_owned()
};
Ok(EmbeddingRecord {
id: row.get(0)?,
face_id: row.get(1)?,
embedding,
model_name: row.get(3)?,
created_at: row.get(4)?,
})
})
.change_context(Error)?;
for embedding in embedding_iter {
embeddings.push(embedding.change_context(Error)?);
}
} else {
let mut stmt = self
.conn
.prepare("SELECT id, face_id, embedding, model_name, created_at FROM embeddings")
.change_context(Error)?;
let embedding_iter = stmt
.query_map([], |row| {
let embedding_bytes: Vec<u8> = row.get(2)?;
let embedding: ndarray::Array1<f32> = {
let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes)
.change_context(Error)
.unwrap();
sf.tensor::<f32, ndarray::Ix1>("embedding")
.unwrap()
.to_owned()
};
Ok(EmbeddingRecord {
id: row.get(0)?,
face_id: row.get(1)?,
embedding,
model_name: row.get(3)?,
created_at: row.get(4)?,
})
})
.change_context(Error)?;
for embedding in embedding_iter {
embeddings.push(embedding.change_context(Error)?);
}
}
Ok(embeddings)
}
}
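The two branches of get_all_embeddings duplicate the row-mapping closure and differ only in the SQL filter. A possible follow-up (a sketch, not part of this commit) is to pull the mapping into a helper reused by both query_map calls; it relies on the same imports and the same panic-on-malformed-blob behaviour as the closures above:

    // Hypothetical helper, not part of this commit; reuses the decoding calls from the closures
    // above and assumes the surrounding module's imports (ResultExt, Error, EmbeddingRecord).
    fn row_to_embedding_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<EmbeddingRecord> {
        let embedding_bytes: Vec<u8> = row.get(2)?;
        let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes)
            .change_context(Error)
            .unwrap(); // same unwrap-on-bad-blob behaviour as the original closures
        Ok(EmbeddingRecord {
            id: row.get(0)?,
            face_id: row.get(1)?,
            embedding: sf.tensor::<f32, ndarray::Ix1>("embedding").unwrap().to_owned(),
            model_name: row.get(3)?,
            created_at: row.get(4)?,
        })
    }

Both branches could then pass |row| row_to_embedding_record(row) to query_map, leaving a single place that knows how embeddings are serialized.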
fn add_sqlite_cosine_similarity(db: &Connection) -> Result<()> {