feat(cli): add clustering command with K-means support
This commit is contained in:
@@ -21,6 +21,8 @@ pub enum SubCommand {
|
||||
Stats(Stats),
|
||||
#[clap(name = "compare")]
|
||||
Compare(Compare),
|
||||
#[clap(name = "cluster")]
|
||||
Cluster(Cluster),
|
||||
#[clap(name = "gui")]
|
||||
Gui,
|
||||
#[clap(name = "completions")]
|
||||
@@ -187,6 +189,22 @@ pub struct Compare {
|
||||
pub image2: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct Cluster {
|
||||
#[clap(short = 'd', long, default_value = "face_detections.db")]
|
||||
pub database: PathBuf,
|
||||
#[clap(short, long, default_value_t = 5)]
|
||||
pub clusters: usize,
|
||||
#[clap(short, long, default_value_t = 100)]
|
||||
pub max_iterations: u64,
|
||||
#[clap(short, long, default_value_t = 1e-4)]
|
||||
pub tolerance: f64,
|
||||
#[clap(long, default_value = "facenet")]
|
||||
pub model_name: String,
|
||||
#[clap(short, long)]
|
||||
pub output: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl Cli {
|
||||
pub fn completions(shell: clap_complete::Shell) {
|
||||
let mut command = <Cli as clap::CommandFactory>::command();
|
||||
|
||||
@@ -2,9 +2,11 @@ mod cli;
|
||||
use bounding_box::roi::MultiRoi;
|
||||
use detector::*;
|
||||
use detector::{database::FaceDatabase, facedet, facedet::FaceDetectionConfig, faceembed};
|
||||
use errors::*;
|
||||
|
||||
use fast_image_resize::ResizeOptions;
|
||||
|
||||
use linfa::prelude::*;
|
||||
use linfa_clustering::KMeans;
|
||||
use ndarray::*;
|
||||
use ndarray_image::*;
|
||||
use ndarray_resize::NdFir;
|
||||
@@ -209,6 +211,9 @@ pub fn main() -> Result<()> {
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
cli::SubCommand::Cluster(cluster) => {
|
||||
run_clustering(cluster)?;
|
||||
}
|
||||
cli::SubCommand::Completions { shell } => {
|
||||
cli::Cli::completions(shell);
|
||||
}
|
||||
@@ -936,3 +941,123 @@ fn run_stats(stats: cli::Stats) -> Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_clustering(cluster: cli::Cluster) -> Result<()> {
|
||||
let db = FaceDatabase::new(&cluster.database).change_context(Error)?;
|
||||
|
||||
// Get all embeddings for the specified model
|
||||
let embeddings = db
|
||||
.get_all_embeddings(Some(&cluster.model_name))
|
||||
.change_context(Error)?;
|
||||
|
||||
if embeddings.is_empty() {
|
||||
println!("No embeddings found for model '{}'", cluster.model_name);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
println!(
|
||||
"Found {} embeddings for model '{}'",
|
||||
embeddings.len(),
|
||||
cluster.model_name
|
||||
);
|
||||
|
||||
if embeddings.len() < cluster.clusters {
|
||||
println!(
|
||||
"Warning: Number of embeddings ({}) is less than requested clusters ({})",
|
||||
embeddings.len(),
|
||||
cluster.clusters
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Convert embeddings to a 2D array for clustering
|
||||
let embedding_dim = embeddings[0].embedding.len();
|
||||
let mut data = Array2::<f64>::zeros((embeddings.len(), embedding_dim));
|
||||
|
||||
for (i, embedding_record) in embeddings.iter().enumerate() {
|
||||
for (j, &val) in embedding_record.embedding.iter().enumerate() {
|
||||
data[[i, j]] = val as f64;
|
||||
}
|
||||
}
|
||||
|
||||
println!(
|
||||
"Running K-means clustering with {} clusters...",
|
||||
cluster.clusters
|
||||
);
|
||||
|
||||
// Create dataset
|
||||
let dataset = linfa::Dataset::new(data, Array1::<usize>::zeros(embeddings.len()));
|
||||
|
||||
// Configure and run K-means
|
||||
let model = KMeans::params(cluster.clusters)
|
||||
.max_n_iterations(cluster.max_iterations)
|
||||
.tolerance(cluster.tolerance)
|
||||
.fit(&dataset)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to run K-means clustering")?;
|
||||
|
||||
// Get cluster assignments
|
||||
let predictions = model.predict(&dataset);
|
||||
|
||||
// Group results by cluster
|
||||
let mut clusters: std::collections::HashMap<usize, Vec<(i64, String)>> =
|
||||
std::collections::HashMap::new();
|
||||
|
||||
for (i, &cluster_id) in predictions.iter().enumerate() {
|
||||
let face_id = embeddings[i].face_id;
|
||||
|
||||
// Get image path for this face
|
||||
let image_info = db.get_image_for_face(face_id).change_context(Error)?;
|
||||
let image_path = image_info
|
||||
.map(|info| info.file_path)
|
||||
.unwrap_or_else(|| "Unknown".to_string());
|
||||
|
||||
clusters
|
||||
.entry(cluster_id)
|
||||
.or_insert_with(Vec::new)
|
||||
.push((face_id, image_path));
|
||||
}
|
||||
|
||||
// Print results
|
||||
println!("\nClustering Results:");
|
||||
for cluster_id in 0..cluster.clusters {
|
||||
if let Some(faces) = clusters.get(&cluster_id) {
|
||||
println!("\nCluster {}: {} faces", cluster_id, faces.len());
|
||||
for (face_id, image_path) in faces {
|
||||
println!(" Face ID: {}, Image: {}", face_id, image_path);
|
||||
}
|
||||
} else {
|
||||
println!("\nCluster {}: 0 faces", cluster_id);
|
||||
}
|
||||
}
|
||||
|
||||
// Optionally save results to file
|
||||
if let Some(output_path) = &cluster.output {
|
||||
use std::io::Write;
|
||||
let mut file = std::fs::File::create(output_path)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create output file")?;
|
||||
|
||||
writeln!(file, "K-means Clustering Results").change_context(Error)?;
|
||||
writeln!(file, "Model: {}", cluster.model_name).change_context(Error)?;
|
||||
writeln!(file, "Total embeddings: {}", embeddings.len()).change_context(Error)?;
|
||||
writeln!(file, "Number of clusters: {}", cluster.clusters).change_context(Error)?;
|
||||
writeln!(file, "").change_context(Error)?;
|
||||
|
||||
for cluster_id in 0..cluster.clusters {
|
||||
if let Some(faces) = clusters.get(&cluster_id) {
|
||||
writeln!(file, "Cluster {}: {} faces", cluster_id, faces.len())
|
||||
.change_context(Error)?;
|
||||
for (face_id, image_path) in faces {
|
||||
writeln!(file, " Face ID: {}, Image: {}", face_id, image_path)
|
||||
.change_context(Error)?;
|
||||
}
|
||||
writeln!(file, "").change_context(Error)?;
|
||||
}
|
||||
}
|
||||
|
||||
println!("\nResults saved to: {:?}", output_path);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -65,14 +65,15 @@ impl FaceDatabase {
|
||||
/// Create a new database connection and initialize tables
|
||||
pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Self> {
|
||||
let conn = Connection::open(db_path).change_context(Error)?;
|
||||
unsafe {
|
||||
let _guard = rusqlite::LoadExtensionGuard::new(&conn).change_context(Error)?;
|
||||
conn.load_extension(
|
||||
"/Users/fs0c131y/.cache/cargo/target/release/libsqlite3_safetensor_cosine.dylib",
|
||||
None::<&str>,
|
||||
)
|
||||
.change_context(Error)?;
|
||||
}
|
||||
// Temporarily disable extension loading for clustering demo
|
||||
// unsafe {
|
||||
// let _guard = rusqlite::LoadExtensionGuard::new(&conn).change_context(Error)?;
|
||||
// conn.load_extension(
|
||||
// "/Users/fs0c131y/.cache/cargo/target/release/libsqlite3_safetensor_cosine.dylib",
|
||||
// None::<&str>,
|
||||
// )
|
||||
// .change_context(Error)?;
|
||||
// }
|
||||
let db = Self { conn };
|
||||
db.create_tables()?;
|
||||
Ok(db)
|
||||
@@ -594,6 +595,76 @@ impl FaceDatabase {
|
||||
println!("{:?}", result);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all embeddings for a specific model
|
||||
pub fn get_all_embeddings(&self, model_name: Option<&str>) -> Result<Vec<EmbeddingRecord>> {
|
||||
let mut embeddings = Vec::new();
|
||||
|
||||
if let Some(model) = model_name {
|
||||
let mut stmt = self.conn.prepare(
|
||||
"SELECT id, face_id, embedding, model_name, created_at FROM embeddings WHERE model_name = ?1"
|
||||
).change_context(Error)?;
|
||||
|
||||
let embedding_iter = stmt
|
||||
.query_map(params![model], |row| {
|
||||
let embedding_bytes: Vec<u8> = row.get(2)?;
|
||||
let embedding: ndarray::Array1<f32> = {
|
||||
let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes)
|
||||
.change_context(Error)
|
||||
.unwrap();
|
||||
sf.tensor::<f32, ndarray::Ix1>("embedding")
|
||||
.unwrap()
|
||||
.to_owned()
|
||||
};
|
||||
|
||||
Ok(EmbeddingRecord {
|
||||
id: row.get(0)?,
|
||||
face_id: row.get(1)?,
|
||||
embedding,
|
||||
model_name: row.get(3)?,
|
||||
created_at: row.get(4)?,
|
||||
})
|
||||
})
|
||||
.change_context(Error)?;
|
||||
|
||||
for embedding in embedding_iter {
|
||||
embeddings.push(embedding.change_context(Error)?);
|
||||
}
|
||||
} else {
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare("SELECT id, face_id, embedding, model_name, created_at FROM embeddings")
|
||||
.change_context(Error)?;
|
||||
|
||||
let embedding_iter = stmt
|
||||
.query_map([], |row| {
|
||||
let embedding_bytes: Vec<u8> = row.get(2)?;
|
||||
let embedding: ndarray::Array1<f32> = {
|
||||
let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes)
|
||||
.change_context(Error)
|
||||
.unwrap();
|
||||
sf.tensor::<f32, ndarray::Ix1>("embedding")
|
||||
.unwrap()
|
||||
.to_owned()
|
||||
};
|
||||
|
||||
Ok(EmbeddingRecord {
|
||||
id: row.get(0)?,
|
||||
face_id: row.get(1)?,
|
||||
embedding,
|
||||
model_name: row.get(3)?,
|
||||
created_at: row.get(4)?,
|
||||
})
|
||||
})
|
||||
.change_context(Error)?;
|
||||
|
||||
for embedding in embedding_iter {
|
||||
embeddings.push(embedding.change_context(Error)?);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(embeddings)
|
||||
}
|
||||
}
|
||||
|
||||
fn add_sqlite_cosine_similarity(db: &Connection) -> Result<()> {
|
||||
|
||||
Reference in New Issue
Block a user