feat(cli): add clustering command with K-means support
This commit is contained in:
111
Cargo.lock
generated
111
Cargo.lock
generated
@@ -1272,6 +1272,7 @@ dependencies = [
|
||||
"imageproc",
|
||||
"itertools 0.14.0",
|
||||
"linfa",
|
||||
"linfa-clustering",
|
||||
"mnn",
|
||||
"mnn-bridge",
|
||||
"mnn-sync",
|
||||
@@ -2809,6 +2810,16 @@ dependencies = [
|
||||
"mutate_once",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "kdtree"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0f0a0e9f770b65bac9aad00f97a67ab5c5319effed07f6da385da3c2115e47ba"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "khronos-egl"
|
||||
version = "6.0.0"
|
||||
@@ -2930,6 +2941,50 @@ dependencies = [
|
||||
"thiserror 2.0.15",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linfa-clustering"
|
||||
version = "0.7.1"
|
||||
source = "git+https://github.com/relf/linfa?branch=upgrade-ndarray-0.16#c1fbee7c54e806de3f5fb2c5240ce163d000f1ba"
|
||||
dependencies = [
|
||||
"linfa",
|
||||
"linfa-linalg",
|
||||
"linfa-nn",
|
||||
"ndarray",
|
||||
"ndarray-rand",
|
||||
"ndarray-stats",
|
||||
"noisy_float",
|
||||
"num-traits",
|
||||
"rand_xoshiro",
|
||||
"space",
|
||||
"thiserror 2.0.15",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linfa-linalg"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "02a834c0ec063937688a0d13573aa515ab8c425bd8de3154b908dd3b9c197dc4"
|
||||
dependencies = [
|
||||
"ndarray",
|
||||
"num-traits",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linfa-nn"
|
||||
version = "0.7.2"
|
||||
source = "git+https://github.com/relf/linfa?branch=upgrade-ndarray-0.16#c1fbee7c54e806de3f5fb2c5240ce163d000f1ba"
|
||||
dependencies = [
|
||||
"kdtree",
|
||||
"linfa",
|
||||
"ndarray",
|
||||
"ndarray-stats",
|
||||
"noisy_float",
|
||||
"num-traits",
|
||||
"order-stat",
|
||||
"thiserror 2.0.15",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.4.15"
|
||||
@@ -3280,6 +3335,17 @@ dependencies = [
|
||||
"zip",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ndarray-rand"
|
||||
version = "0.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f093b3db6fd194718dcdeea6bd8c829417deae904e3fcc7732dabcd4416d25d8"
|
||||
dependencies = [
|
||||
"ndarray",
|
||||
"rand 0.8.5",
|
||||
"rand_distr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ndarray-resize"
|
||||
version = "0.1.0"
|
||||
@@ -3303,6 +3369,21 @@ dependencies = [
|
||||
"thiserror 2.0.15",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ndarray-stats"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "17ebbe97acce52d06aebed4cd4a87c0941f4b2519b59b82b4feb5bd0ce003dfd"
|
||||
dependencies = [
|
||||
"indexmap 2.10.0",
|
||||
"itertools 0.13.0",
|
||||
"ndarray",
|
||||
"noisy_float",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
"rand 0.8.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ndcv-bridge"
|
||||
version = "0.1.0"
|
||||
@@ -3395,6 +3476,15 @@ dependencies = [
|
||||
"memoffset",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "noisy_float"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "978fe6e6ebc0bf53de533cd456ca2d9de13de13856eda1518a285d7705a213af"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nom"
|
||||
version = "7.1.3"
|
||||
@@ -3889,6 +3979,12 @@ dependencies = [
|
||||
"libredox",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "order-stat"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "efa535d5117d3661134dbf1719b6f0ffe06f2375843b13935db186cd094105eb"
|
||||
|
||||
[[package]]
|
||||
name = "ordered-float"
|
||||
version = "5.0.0"
|
||||
@@ -4437,6 +4533,15 @@ dependencies = [
|
||||
"rand 0.8.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_xoshiro"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa"
|
||||
dependencies = [
|
||||
"rand_core 0.6.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "range-alloc"
|
||||
version = "0.1.4"
|
||||
@@ -5108,6 +5213,12 @@ dependencies = [
|
||||
"x11rb",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "space"
|
||||
version = "0.12.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e990cc6cb89a82d70fe722cd7811dbce48a72bbfaebd623e58f142b6db28428f"
|
||||
|
||||
[[package]]
|
||||
name = "spin"
|
||||
version = "0.9.8"
|
||||
|
||||
@@ -16,6 +16,7 @@ edition = "2024"
|
||||
|
||||
[patch.crates-io]
|
||||
linfa = { git = "https://github.com/relf/linfa", branch = "upgrade-ndarray-0.16" }
|
||||
linfa-clustering = { git = "https://github.com/relf/linfa", branch = "upgrade-ndarray-0.16" }
|
||||
|
||||
[workspace.dependencies]
|
||||
divan = { version = "0.1.21" }
|
||||
@@ -90,6 +91,7 @@ rfd = "0.15"
|
||||
futures = "0.3"
|
||||
imageproc = "0.25"
|
||||
linfa = "0.7.1"
|
||||
linfa-clustering = "0.7.1"
|
||||
|
||||
[profile.release]
|
||||
debug = true
|
||||
|
||||
@@ -21,6 +21,8 @@ pub enum SubCommand {
|
||||
Stats(Stats),
|
||||
#[clap(name = "compare")]
|
||||
Compare(Compare),
|
||||
#[clap(name = "cluster")]
|
||||
Cluster(Cluster),
|
||||
#[clap(name = "gui")]
|
||||
Gui,
|
||||
#[clap(name = "completions")]
|
||||
@@ -187,6 +189,22 @@ pub struct Compare {
|
||||
pub image2: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct Cluster {
|
||||
#[clap(short = 'd', long, default_value = "face_detections.db")]
|
||||
pub database: PathBuf,
|
||||
#[clap(short, long, default_value_t = 5)]
|
||||
pub clusters: usize,
|
||||
#[clap(short, long, default_value_t = 100)]
|
||||
pub max_iterations: u64,
|
||||
#[clap(short, long, default_value_t = 1e-4)]
|
||||
pub tolerance: f64,
|
||||
#[clap(long, default_value = "facenet")]
|
||||
pub model_name: String,
|
||||
#[clap(short, long)]
|
||||
pub output: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl Cli {
|
||||
pub fn completions(shell: clap_complete::Shell) {
|
||||
let mut command = <Cli as clap::CommandFactory>::command();
|
||||
|
||||
@@ -2,9 +2,11 @@ mod cli;
|
||||
use bounding_box::roi::MultiRoi;
|
||||
use detector::*;
|
||||
use detector::{database::FaceDatabase, facedet, facedet::FaceDetectionConfig, faceembed};
|
||||
use errors::*;
|
||||
|
||||
use fast_image_resize::ResizeOptions;
|
||||
|
||||
use linfa::prelude::*;
|
||||
use linfa_clustering::KMeans;
|
||||
use ndarray::*;
|
||||
use ndarray_image::*;
|
||||
use ndarray_resize::NdFir;
|
||||
@@ -209,6 +211,9 @@ pub fn main() -> Result<()> {
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
cli::SubCommand::Cluster(cluster) => {
|
||||
run_clustering(cluster)?;
|
||||
}
|
||||
cli::SubCommand::Completions { shell } => {
|
||||
cli::Cli::completions(shell);
|
||||
}
|
||||
@@ -936,3 +941,123 @@ fn run_stats(stats: cli::Stats) -> Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_clustering(cluster: cli::Cluster) -> Result<()> {
|
||||
let db = FaceDatabase::new(&cluster.database).change_context(Error)?;
|
||||
|
||||
// Get all embeddings for the specified model
|
||||
let embeddings = db
|
||||
.get_all_embeddings(Some(&cluster.model_name))
|
||||
.change_context(Error)?;
|
||||
|
||||
if embeddings.is_empty() {
|
||||
println!("No embeddings found for model '{}'", cluster.model_name);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
println!(
|
||||
"Found {} embeddings for model '{}'",
|
||||
embeddings.len(),
|
||||
cluster.model_name
|
||||
);
|
||||
|
||||
if embeddings.len() < cluster.clusters {
|
||||
println!(
|
||||
"Warning: Number of embeddings ({}) is less than requested clusters ({})",
|
||||
embeddings.len(),
|
||||
cluster.clusters
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Convert embeddings to a 2D array for clustering
|
||||
let embedding_dim = embeddings[0].embedding.len();
|
||||
let mut data = Array2::<f64>::zeros((embeddings.len(), embedding_dim));
|
||||
|
||||
for (i, embedding_record) in embeddings.iter().enumerate() {
|
||||
for (j, &val) in embedding_record.embedding.iter().enumerate() {
|
||||
data[[i, j]] = val as f64;
|
||||
}
|
||||
}
|
||||
|
||||
println!(
|
||||
"Running K-means clustering with {} clusters...",
|
||||
cluster.clusters
|
||||
);
|
||||
|
||||
// Create dataset
|
||||
let dataset = linfa::Dataset::new(data, Array1::<usize>::zeros(embeddings.len()));
|
||||
|
||||
// Configure and run K-means
|
||||
let model = KMeans::params(cluster.clusters)
|
||||
.max_n_iterations(cluster.max_iterations)
|
||||
.tolerance(cluster.tolerance)
|
||||
.fit(&dataset)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to run K-means clustering")?;
|
||||
|
||||
// Get cluster assignments
|
||||
let predictions = model.predict(&dataset);
|
||||
|
||||
// Group results by cluster
|
||||
let mut clusters: std::collections::HashMap<usize, Vec<(i64, String)>> =
|
||||
std::collections::HashMap::new();
|
||||
|
||||
for (i, &cluster_id) in predictions.iter().enumerate() {
|
||||
let face_id = embeddings[i].face_id;
|
||||
|
||||
// Get image path for this face
|
||||
let image_info = db.get_image_for_face(face_id).change_context(Error)?;
|
||||
let image_path = image_info
|
||||
.map(|info| info.file_path)
|
||||
.unwrap_or_else(|| "Unknown".to_string());
|
||||
|
||||
clusters
|
||||
.entry(cluster_id)
|
||||
.or_insert_with(Vec::new)
|
||||
.push((face_id, image_path));
|
||||
}
|
||||
|
||||
// Print results
|
||||
println!("\nClustering Results:");
|
||||
for cluster_id in 0..cluster.clusters {
|
||||
if let Some(faces) = clusters.get(&cluster_id) {
|
||||
println!("\nCluster {}: {} faces", cluster_id, faces.len());
|
||||
for (face_id, image_path) in faces {
|
||||
println!(" Face ID: {}, Image: {}", face_id, image_path);
|
||||
}
|
||||
} else {
|
||||
println!("\nCluster {}: 0 faces", cluster_id);
|
||||
}
|
||||
}
|
||||
|
||||
// Optionally save results to file
|
||||
if let Some(output_path) = &cluster.output {
|
||||
use std::io::Write;
|
||||
let mut file = std::fs::File::create(output_path)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create output file")?;
|
||||
|
||||
writeln!(file, "K-means Clustering Results").change_context(Error)?;
|
||||
writeln!(file, "Model: {}", cluster.model_name).change_context(Error)?;
|
||||
writeln!(file, "Total embeddings: {}", embeddings.len()).change_context(Error)?;
|
||||
writeln!(file, "Number of clusters: {}", cluster.clusters).change_context(Error)?;
|
||||
writeln!(file, "").change_context(Error)?;
|
||||
|
||||
for cluster_id in 0..cluster.clusters {
|
||||
if let Some(faces) = clusters.get(&cluster_id) {
|
||||
writeln!(file, "Cluster {}: {} faces", cluster_id, faces.len())
|
||||
.change_context(Error)?;
|
||||
for (face_id, image_path) in faces {
|
||||
writeln!(file, " Face ID: {}, Image: {}", face_id, image_path)
|
||||
.change_context(Error)?;
|
||||
}
|
||||
writeln!(file, "").change_context(Error)?;
|
||||
}
|
||||
}
|
||||
|
||||
println!("\nResults saved to: {:?}", output_path);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -65,14 +65,15 @@ impl FaceDatabase {
|
||||
/// Create a new database connection and initialize tables
|
||||
pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Self> {
|
||||
let conn = Connection::open(db_path).change_context(Error)?;
|
||||
unsafe {
|
||||
let _guard = rusqlite::LoadExtensionGuard::new(&conn).change_context(Error)?;
|
||||
conn.load_extension(
|
||||
"/Users/fs0c131y/.cache/cargo/target/release/libsqlite3_safetensor_cosine.dylib",
|
||||
None::<&str>,
|
||||
)
|
||||
.change_context(Error)?;
|
||||
}
|
||||
// Temporarily disable extension loading for clustering demo
|
||||
// unsafe {
|
||||
// let _guard = rusqlite::LoadExtensionGuard::new(&conn).change_context(Error)?;
|
||||
// conn.load_extension(
|
||||
// "/Users/fs0c131y/.cache/cargo/target/release/libsqlite3_safetensor_cosine.dylib",
|
||||
// None::<&str>,
|
||||
// )
|
||||
// .change_context(Error)?;
|
||||
// }
|
||||
let db = Self { conn };
|
||||
db.create_tables()?;
|
||||
Ok(db)
|
||||
@@ -594,6 +595,76 @@ impl FaceDatabase {
|
||||
println!("{:?}", result);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all embeddings for a specific model
|
||||
pub fn get_all_embeddings(&self, model_name: Option<&str>) -> Result<Vec<EmbeddingRecord>> {
|
||||
let mut embeddings = Vec::new();
|
||||
|
||||
if let Some(model) = model_name {
|
||||
let mut stmt = self.conn.prepare(
|
||||
"SELECT id, face_id, embedding, model_name, created_at FROM embeddings WHERE model_name = ?1"
|
||||
).change_context(Error)?;
|
||||
|
||||
let embedding_iter = stmt
|
||||
.query_map(params![model], |row| {
|
||||
let embedding_bytes: Vec<u8> = row.get(2)?;
|
||||
let embedding: ndarray::Array1<f32> = {
|
||||
let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes)
|
||||
.change_context(Error)
|
||||
.unwrap();
|
||||
sf.tensor::<f32, ndarray::Ix1>("embedding")
|
||||
.unwrap()
|
||||
.to_owned()
|
||||
};
|
||||
|
||||
Ok(EmbeddingRecord {
|
||||
id: row.get(0)?,
|
||||
face_id: row.get(1)?,
|
||||
embedding,
|
||||
model_name: row.get(3)?,
|
||||
created_at: row.get(4)?,
|
||||
})
|
||||
})
|
||||
.change_context(Error)?;
|
||||
|
||||
for embedding in embedding_iter {
|
||||
embeddings.push(embedding.change_context(Error)?);
|
||||
}
|
||||
} else {
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare("SELECT id, face_id, embedding, model_name, created_at FROM embeddings")
|
||||
.change_context(Error)?;
|
||||
|
||||
let embedding_iter = stmt
|
||||
.query_map([], |row| {
|
||||
let embedding_bytes: Vec<u8> = row.get(2)?;
|
||||
let embedding: ndarray::Array1<f32> = {
|
||||
let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes)
|
||||
.change_context(Error)
|
||||
.unwrap();
|
||||
sf.tensor::<f32, ndarray::Ix1>("embedding")
|
||||
.unwrap()
|
||||
.to_owned()
|
||||
};
|
||||
|
||||
Ok(EmbeddingRecord {
|
||||
id: row.get(0)?,
|
||||
face_id: row.get(1)?,
|
||||
embedding,
|
||||
model_name: row.get(3)?,
|
||||
created_at: row.get(4)?,
|
||||
})
|
||||
})
|
||||
.change_context(Error)?;
|
||||
|
||||
for embedding in embedding_iter {
|
||||
embeddings.push(embedding.change_context(Error)?);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(embeddings)
|
||||
}
|
||||
}
|
||||
|
||||
fn add_sqlite_cosine_similarity(db: &Connection) -> Result<()> {
|
||||
|
||||
Reference in New Issue
Block a user