Compare commits
3 Commits
61466c9edd
...
97f64e7e10
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
97f64e7e10 | ||
|
|
37adb74adf | ||
|
|
47218fa696 |
48
Cargo.lock
generated
48
Cargo.lock
generated
@@ -269,6 +269,20 @@ name = "bytemuck"
|
|||||||
version = "1.23.2"
|
version = "1.23.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "3995eaeebcdf32f91f980d360f78732ddc061097ab4e39991ae7a6ace9194677"
|
checksum = "3995eaeebcdf32f91f980d360f78732ddc061097ab4e39991ae7a6ace9194677"
|
||||||
|
dependencies = [
|
||||||
|
"bytemuck_derive",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bytemuck_derive"
|
||||||
|
version = "1.10.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4f154e572231cb6ba2bd1176980827e3d5dc04cc183a75dea38109fbdd672d29"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "byteorder-lite"
|
name = "byteorder-lite"
|
||||||
@@ -504,7 +518,9 @@ dependencies = [
|
|||||||
"nalgebra",
|
"nalgebra",
|
||||||
"ndarray",
|
"ndarray",
|
||||||
"ndarray-image",
|
"ndarray-image",
|
||||||
|
"ndarray-math",
|
||||||
"ndarray-resize",
|
"ndarray-resize",
|
||||||
|
"ndarray-safetensors",
|
||||||
"ordered-float",
|
"ordered-float",
|
||||||
"ort",
|
"ort",
|
||||||
"rusqlite",
|
"rusqlite",
|
||||||
@@ -830,6 +846,7 @@ version = "2.6.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9"
|
checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"bytemuck",
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
"crunchy",
|
"crunchy",
|
||||||
]
|
]
|
||||||
@@ -1414,6 +1431,16 @@ dependencies = [
|
|||||||
"ndarray",
|
"ndarray",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ndarray-math"
|
||||||
|
version = "0.1.0"
|
||||||
|
source = "git+https://git.darksailor.dev/servius/ndarray-math#f047966f20835267f20e5839272b9ab36c445796"
|
||||||
|
dependencies = [
|
||||||
|
"ndarray",
|
||||||
|
"num",
|
||||||
|
"thiserror 2.0.15",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ndarray-resize"
|
name = "ndarray-resize"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
@@ -1426,6 +1453,17 @@ dependencies = [
|
|||||||
"thiserror 2.0.15",
|
"thiserror 2.0.15",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ndarray-safetensors"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"bytemuck",
|
||||||
|
"half",
|
||||||
|
"ndarray",
|
||||||
|
"safetensors",
|
||||||
|
"thiserror 2.0.15",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "new_debug_unreachable"
|
name = "new_debug_unreachable"
|
||||||
version = "1.0.6"
|
version = "1.0.6"
|
||||||
@@ -1983,6 +2021,16 @@ dependencies = [
|
|||||||
"bytemuck",
|
"bytemuck",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "safetensors"
|
||||||
|
version = "0.6.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "172dd94c5a87b5c79f945c863da53b2ebc7ccef4eca24ac63cca66a41aab2178"
|
||||||
|
dependencies = [
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "scopeguard"
|
name = "scopeguard"
|
||||||
version = "1.2.0"
|
version = "1.2.0"
|
||||||
|
|||||||
12
Cargo.toml
12
Cargo.toml
@@ -1,5 +1,5 @@
|
|||||||
[workspace]
|
[workspace]
|
||||||
members = ["ndarray-image", "ndarray-resize", ".", "bounding-box"]
|
members = ["ndarray-image", "ndarray-resize", ".", "bounding-box", "ndarray-safetensors"]
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
@@ -37,7 +37,7 @@ nalgebra = { workspace = true }
|
|||||||
ndarray = "0.16.1"
|
ndarray = "0.16.1"
|
||||||
ndarray-image = { workspace = true }
|
ndarray-image = { workspace = true }
|
||||||
ndarray-resize = { workspace = true }
|
ndarray-resize = { workspace = true }
|
||||||
rusqlite = { version = "0.37.0", features = ["modern-full"] }
|
rusqlite = { version = "0.37.0", features = ["functions", "modern-full"] }
|
||||||
tap = "1.0.1"
|
tap = "1.0.1"
|
||||||
thiserror = "2.0"
|
thiserror = "2.0"
|
||||||
tokio = "1.43.1"
|
tokio = "1.43.1"
|
||||||
@@ -50,11 +50,9 @@ bounding-box = { version = "0.1.0", path = "bounding-box" }
|
|||||||
color = "0.3.1"
|
color = "0.3.1"
|
||||||
itertools = "0.14.0"
|
itertools = "0.14.0"
|
||||||
ordered-float = "5.0.0"
|
ordered-float = "5.0.0"
|
||||||
ort = { version = "2.0.0-rc.10", default-features = false, features = [
|
ort = { version = "2.0.0-rc.10", default-features = false, features = [ "std", "tracing", "ndarray"]}
|
||||||
"std",
|
ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" }
|
||||||
"tracing",
|
ndarray-safetensors = { version = "0.1.0", path = "ndarray-safetensors" }
|
||||||
"ndarray",
|
|
||||||
] }
|
|
||||||
|
|
||||||
[profile.release]
|
[profile.release]
|
||||||
debug = true
|
debug = true
|
||||||
|
|||||||
@@ -114,16 +114,17 @@
|
|||||||
stdenv = p: p.clangStdenv;
|
stdenv = p: p.clangStdenv;
|
||||||
doCheck = false;
|
doCheck = false;
|
||||||
LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib";
|
LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib";
|
||||||
ORT_LIB_LOCATION = "${patchedOnnxruntime}";
|
# ORT_LIB_LOCATION = "${patchedOnnxruntime}";
|
||||||
ORT_ENV_SYSTEM_LIB_LOCATION = "${patchedOnnxruntime}/lib";
|
# ORT_ENV_SYSTEM_LIB_LOCATION = "${patchedOnnxruntime}/lib";
|
||||||
ORT_ENV_PREFER_DYNAMIC_LINK = true;
|
# ORT_ENV_PREFER_DYNAMIC_LINK = true;
|
||||||
nativeBuildInputs = with pkgs; [
|
nativeBuildInputs = with pkgs; [
|
||||||
cmake
|
cmake
|
||||||
pkg-config
|
pkg-config
|
||||||
];
|
];
|
||||||
buildInputs = with pkgs;
|
buildInputs = with pkgs;
|
||||||
[
|
[
|
||||||
# onnxruntime
|
patchedOnnxruntime
|
||||||
|
sqlite
|
||||||
]
|
]
|
||||||
++ (lib.optionals pkgs.stdenv.isDarwin [
|
++ (lib.optionals pkgs.stdenv.isDarwin [
|
||||||
libiconv
|
libiconv
|
||||||
|
|||||||
4
justfile
4
justfile
@@ -9,5 +9,5 @@ open:
|
|||||||
bench:
|
bench:
|
||||||
cargo build --release
|
cargo build --release
|
||||||
BINARY="" hyperfine --warmup 3 --export-markdown benchmark.md \
|
BINARY="" hyperfine --warmup 3 --export-markdown benchmark.md \
|
||||||
"$CARGO_TARGET_DIR/release/detector detect -f coreml selfie.jpg" \
|
"$CARGO_TARGET_DIR/release/detector detect -f cpu selfie.jpg" \
|
||||||
"$CARGO_TARGET_DIR/release/detector detect -f coreml -b 16 selfie.jpg"
|
"$CARGO_TARGET_DIR/release/detector detect -f cpu -b 1 selfie.jpg"
|
||||||
|
|||||||
11
ndarray-safetensors/Cargo.toml
Normal file
11
ndarray-safetensors/Cargo.toml
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
[package]
|
||||||
|
name = "ndarray-safetensors"
|
||||||
|
version.workspace = true
|
||||||
|
edition.workspace = true
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
bytemuck = { version = "1.23.2" }
|
||||||
|
half = { version = "2.6.0", default-features = false, features = ["bytemuck"] }
|
||||||
|
ndarray = { version = "0.16.1", default-features = false, features = ["std"] }
|
||||||
|
safetensors = "0.6.2"
|
||||||
|
thiserror = "2.0.15"
|
||||||
432
ndarray-safetensors/src/lib.rs
Normal file
432
ndarray-safetensors/src/lib.rs
Normal file
@@ -0,0 +1,432 @@
|
|||||||
|
//! # ndarray-serialize
|
||||||
|
//!
|
||||||
|
//! A Rust library for serializing and deserializing `ndarray` arrays using the SafeTensors format.
|
||||||
|
//!
|
||||||
|
//! ## Features
|
||||||
|
//! - Serialize `ndarray::ArrayView` to SafeTensors format
|
||||||
|
//! - Deserialize SafeTensors data back to `ndarray::ArrayView`
|
||||||
|
//! - Support for multiple data types (f32, f64, i8-i64, u8-u64, f16, bf16)
|
||||||
|
//! - Zero-copy deserialization when possible
|
||||||
|
//! - Metadata support
|
||||||
|
//!
|
||||||
|
//! ## Example
|
||||||
|
//! ```rust
|
||||||
|
//! use ndarray::Array2;
|
||||||
|
//! use ndarray_safetensors::{SafeArrays, SafeArrayView};
|
||||||
|
//!
|
||||||
|
//! // Create some data
|
||||||
|
//! let array = Array2::<f32>::zeros((3, 4));
|
||||||
|
//!
|
||||||
|
//! // Serialize
|
||||||
|
//! let mut safe_arrays = SafeArrays::new();
|
||||||
|
//! safe_arrays.insert_ndarray("my_tensor", array.view()).unwrap();
|
||||||
|
//! safe_arrays.insert_metadata("author", "example");
|
||||||
|
//! let bytes = safe_arrays.serialize().unwrap();
|
||||||
|
//!
|
||||||
|
//! // Deserialize
|
||||||
|
//! let view = SafeArrayView::from_bytes(&bytes).unwrap();
|
||||||
|
//! let tensor: ndarray::ArrayView2<f32> = view.tensor("my_tensor").unwrap();
|
||||||
|
//! assert_eq!(tensor.shape(), &[3, 4]);
|
||||||
|
//! ```
|
||||||
|
|
||||||
|
use safetensors::View;
|
||||||
|
use std::borrow::Cow;
|
||||||
|
use std::collections::{BTreeMap, HashMap};
|
||||||
|
|
||||||
|
use thiserror::Error;
|
||||||
|
/// Errors that can occur during SafeTensor operations
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum SafeTensorError {
|
||||||
|
#[error("Tensor not found: {0}")]
|
||||||
|
TensorNotFound(String),
|
||||||
|
#[error("Invalid tensor data: Got {0} Expected: {1}")]
|
||||||
|
InvalidTensorData(&'static str, String),
|
||||||
|
#[error("IO error: {0}")]
|
||||||
|
IoError(#[from] std::io::Error),
|
||||||
|
#[error("Safetensor error: {0}")]
|
||||||
|
SafeTensor(#[from] safetensors::SafeTensorError),
|
||||||
|
#[error("ndarray::ShapeError error: {0}")]
|
||||||
|
NdarrayShapeError(#[from] ndarray::ShapeError),
|
||||||
|
}
|
||||||
|
|
||||||
|
type Result<T, E = SafeTensorError> = core::result::Result<T, E>;
|
||||||
|
|
||||||
|
use safetensors::tensor::SafeTensors;
|
||||||
|
|
||||||
|
/// A view into SafeTensors data that provides access to ndarray tensors
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
/// ```rust
|
||||||
|
/// use ndarray::Array2;
|
||||||
|
/// use ndarray_safetensors::{SafeArrays, SafeArrayView};
|
||||||
|
///
|
||||||
|
/// let array = Array2::<f32>::ones((2, 3));
|
||||||
|
/// let mut safe_arrays = SafeArrays::new();
|
||||||
|
/// safe_arrays.insert_ndarray("data", array.view()).unwrap();
|
||||||
|
/// let bytes = safe_arrays.serialize().unwrap();
|
||||||
|
///
|
||||||
|
/// let view = SafeArrayView::from_bytes(&bytes).unwrap();
|
||||||
|
/// let tensor: ndarray::ArrayView2<f32> = view.tensor("data").unwrap();
|
||||||
|
/// ```
|
||||||
|
pub struct SafeArraysView<'a> {
|
||||||
|
pub tensors: SafeTensors<'a>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> SafeArraysView<'a> {
|
||||||
|
fn new(tensors: SafeTensors<'a>) -> Self {
|
||||||
|
Self { tensors }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a SafeArrayView from serialized bytes
|
||||||
|
pub fn from_bytes(bytes: &'a [u8]) -> Result<SafeArraysView<'a>> {
|
||||||
|
let tensors = SafeTensors::deserialize(bytes)?;
|
||||||
|
Ok(Self::new(tensors))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a dynamic-dimensional tensor by name
|
||||||
|
pub fn dynamic_tensor<T: STDtype>(&self, name: &str) -> Result<ndarray::ArrayViewD<'a, T>> {
|
||||||
|
self.tensors
|
||||||
|
.tensor(name)
|
||||||
|
.map(|tensor| tensor_view_to_array_view(tensor))?
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a tensor with specific dimensions by name
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
/// ```rust
|
||||||
|
/// # use ndarray::Array2;
|
||||||
|
/// # use ndarray_safetensors::{SafeArrays, SafeArrayView};
|
||||||
|
/// # let array = Array2::<f32>::ones((2, 3));
|
||||||
|
/// # let mut safe_arrays = SafeArrays::new();
|
||||||
|
/// # safe_arrays.insert_ndarray("data", array.view()).unwrap();
|
||||||
|
/// # let bytes = safe_arrays.serialize().unwrap();
|
||||||
|
/// # let view = SafeArrayView::from_bytes(&bytes).unwrap();
|
||||||
|
/// let tensor: ndarray::ArrayView2<f32> = view.tensor("data").unwrap();
|
||||||
|
/// ```
|
||||||
|
pub fn tensor<T: STDtype, Dim: ndarray::Dimension>(
|
||||||
|
&self,
|
||||||
|
name: &str,
|
||||||
|
) -> Result<ndarray::ArrayView<'a, T, Dim>> {
|
||||||
|
Ok(self
|
||||||
|
.tensors
|
||||||
|
.tensor(name)
|
||||||
|
.map(|tensor| tensor_view_to_array_view(tensor))?
|
||||||
|
.map(|array_view| array_view.into_dimensionality::<Dim>())??)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get an iterator over tensor names
|
||||||
|
pub fn names(&self) -> std::vec::IntoIter<&str> {
|
||||||
|
self.tensors.names().into_iter()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the number of tensors
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.tensors.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if there are no tensors
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.tensors.is_empty()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Trait for types that can be stored in SafeTensors
|
||||||
|
///
|
||||||
|
/// Implemented for: f32, f64, i8, i16, i32, i64, u8, u16, u32, u64, f16, bf16
|
||||||
|
pub trait STDtype: bytemuck::Pod {
|
||||||
|
fn dtype() -> safetensors::tensor::Dtype;
|
||||||
|
fn size() -> usize {
|
||||||
|
(Self::dtype().bitsize() / 8).max(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules! impl_dtype {
|
||||||
|
($($t:ty => $dtype:expr),* $(,)?) => {
|
||||||
|
$(
|
||||||
|
impl STDtype for $t {
|
||||||
|
fn dtype() -> safetensors::tensor::Dtype {
|
||||||
|
$dtype
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)*
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
use safetensors::tensor::Dtype;
|
||||||
|
|
||||||
|
impl_dtype!(
|
||||||
|
// bool => Dtype::BOOL, // idk if ndarray::ArrayD<bool> is packed
|
||||||
|
f32 => Dtype::F32,
|
||||||
|
f64 => Dtype::F64,
|
||||||
|
i8 => Dtype::I8,
|
||||||
|
i16 => Dtype::I16,
|
||||||
|
i32 => Dtype::I32,
|
||||||
|
i64 => Dtype::I64,
|
||||||
|
u8 => Dtype::U8,
|
||||||
|
u16 => Dtype::U16,
|
||||||
|
u32 => Dtype::U32,
|
||||||
|
u64 => Dtype::U64,
|
||||||
|
half::f16 => Dtype::F16,
|
||||||
|
half::bf16 => Dtype::BF16,
|
||||||
|
);
|
||||||
|
|
||||||
|
fn tensor_view_to_array_view<'a, T: STDtype>(
|
||||||
|
tensor: safetensors::tensor::TensorView<'a>,
|
||||||
|
) -> Result<ndarray::ArrayViewD<'a, T>> {
|
||||||
|
let shape = tensor.shape();
|
||||||
|
let dtype = tensor.dtype();
|
||||||
|
if T::dtype() != dtype {
|
||||||
|
return Err(SafeTensorError::InvalidTensorData(
|
||||||
|
core::any::type_name::<T>(),
|
||||||
|
dtype.to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let data = tensor.data();
|
||||||
|
let data: &[T] = bytemuck::cast_slice(data);
|
||||||
|
let array = ndarray::ArrayViewD::from_shape(shape, data)?;
|
||||||
|
Ok(array)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builder for creating SafeTensors data from ndarray tensors
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
/// ```rust
|
||||||
|
/// use ndarray::{Array1, Array2};
|
||||||
|
/// use ndarray_safetensors::SafeArrays;
|
||||||
|
///
|
||||||
|
/// let mut safe_arrays = SafeArrays::new();
|
||||||
|
///
|
||||||
|
/// let array1 = Array1::<f32>::from_vec(vec![1.0, 2.0, 3.0]);
|
||||||
|
/// let array2 = Array2::<i32>::zeros((2, 2));
|
||||||
|
///
|
||||||
|
/// safe_arrays.insert_ndarray("vector", array1.view()).unwrap();
|
||||||
|
/// safe_arrays.insert_ndarray("matrix", array2.view()).unwrap();
|
||||||
|
/// safe_arrays.insert_metadata("version", "1.0");
|
||||||
|
///
|
||||||
|
/// let bytes = safe_arrays.serialize().unwrap();
|
||||||
|
/// ```
|
||||||
|
#[derive(Debug, Clone, Default)]
|
||||||
|
#[non_exhaustive]
|
||||||
|
pub struct SafeArrays<'a> {
|
||||||
|
pub tensors: BTreeMap<String, SafeArray<'a>>,
|
||||||
|
pub metadata: Option<HashMap<String, String>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, K: AsRef<str>> FromIterator<(K, SafeArray<'a>)> for SafeArrays<'a> {
|
||||||
|
fn from_iter<T: IntoIterator<Item = (K, SafeArray<'a>)>>(iter: T) -> Self {
|
||||||
|
let tensors = iter
|
||||||
|
.into_iter()
|
||||||
|
.map(|(k, v)| (k.as_ref().to_owned(), v))
|
||||||
|
.collect();
|
||||||
|
Self {
|
||||||
|
tensors,
|
||||||
|
metadata: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, K: AsRef<str>, T: IntoIterator<Item = (K, SafeArray<'a>)>> From<T> for SafeArrays<'a> {
|
||||||
|
fn from(iter: T) -> Self {
|
||||||
|
let tensors = iter
|
||||||
|
.into_iter()
|
||||||
|
.map(|(k, v)| (k.as_ref().to_owned(), v))
|
||||||
|
.collect();
|
||||||
|
Self {
|
||||||
|
tensors,
|
||||||
|
metadata: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> SafeArrays<'a> {
|
||||||
|
/// Create a SafeArrays from an iterator of (name, ndarray::ArrayView) pairs
|
||||||
|
/// ```rust
|
||||||
|
/// use ndarray::{Array2, Array3};
|
||||||
|
/// use ndarray_safetensors::{SafeArrays, SafeArray};
|
||||||
|
/// let array = Array2::<f32>::zeros((3, 4));
|
||||||
|
/// let safe_arrays = SafeArrays::from_ndarrays(vec![
|
||||||
|
/// ("test_tensor", array.view()),
|
||||||
|
/// ("test_tensor2", array.view()),
|
||||||
|
/// ]).unwrap();
|
||||||
|
/// ```
|
||||||
|
|
||||||
|
pub fn from_ndarrays<
|
||||||
|
K: AsRef<str>,
|
||||||
|
T: STDtype,
|
||||||
|
D: ndarray::Dimension + 'a,
|
||||||
|
I: IntoIterator<Item = (K, ndarray::ArrayView<'a, T, D>)>,
|
||||||
|
>(
|
||||||
|
iter: I,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let tensors = iter
|
||||||
|
.into_iter()
|
||||||
|
.map(|(k, v)| Ok((k.as_ref().to_owned(), SafeArray::from_ndarray(v)?)))
|
||||||
|
.collect::<Result<BTreeMap<String, SafeArray<'a>>>>()?;
|
||||||
|
Ok(Self {
|
||||||
|
tensors,
|
||||||
|
metadata: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// impl<'a, K: AsRef<str>, T: IntoIterator<Item = (K, SafeArray<'a>)>> From<T> for SafeArrays<'a> {
|
||||||
|
// fn from(iter: T) -> Self {
|
||||||
|
// let tensors = iter
|
||||||
|
// .into_iter()
|
||||||
|
// .map(|(k, v)| (k.as_ref().to_owned(), v))
|
||||||
|
// .collect();
|
||||||
|
// Self {
|
||||||
|
// tensors,
|
||||||
|
// metadata: None,
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
impl<'a> SafeArrays<'a> {
|
||||||
|
/// Create a new empty SafeArrays builder
|
||||||
|
pub const fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
tensors: BTreeMap::new(),
|
||||||
|
metadata: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Insert a SafeArray tensor with the given name
|
||||||
|
pub fn insert_tensor<'b: 'a>(&mut self, name: impl AsRef<str>, tensor: SafeArray<'b>) {
|
||||||
|
self.tensors.insert(name.as_ref().to_owned(), tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Insert an ndarray tensor with the given name
|
||||||
|
///
|
||||||
|
/// The array must be in standard layout and contiguous.
|
||||||
|
pub fn insert_ndarray<'b: 'a, T: STDtype, D: ndarray::Dimension + 'a>(
|
||||||
|
&mut self,
|
||||||
|
name: impl AsRef<str>,
|
||||||
|
array: ndarray::ArrayView<'b, T, D>,
|
||||||
|
) -> Result<()> {
|
||||||
|
self.insert_tensor(name, SafeArray::from_ndarray(array)?);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Insert metadata key-value pair
|
||||||
|
pub fn insert_metadata(&mut self, key: impl AsRef<str>, value: impl AsRef<str>) {
|
||||||
|
self.metadata
|
||||||
|
.get_or_insert_default()
|
||||||
|
.insert(key.as_ref().to_owned(), value.as_ref().to_owned());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serialize all tensors and metadata to bytes
|
||||||
|
pub fn serialize(self) -> Result<Vec<u8>> {
|
||||||
|
let out = safetensors::serialize(self.tensors, self.metadata)
|
||||||
|
.map_err(SafeTensorError::SafeTensor)?;
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A tensor that can be serialized to SafeTensors format
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct SafeArray<'a> {
|
||||||
|
data: Cow<'a, [u8]>,
|
||||||
|
shape: Vec<usize>,
|
||||||
|
dtype: safetensors::tensor::Dtype,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl View for SafeArray<'_> {
|
||||||
|
fn dtype(&self) -> safetensors::tensor::Dtype {
|
||||||
|
self.dtype
|
||||||
|
}
|
||||||
|
|
||||||
|
fn shape(&self) -> &[usize] {
|
||||||
|
&self.shape
|
||||||
|
}
|
||||||
|
|
||||||
|
fn data(&self) -> Cow<'_, [u8]> {
|
||||||
|
self.data.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn data_len(&self) -> usize {
|
||||||
|
self.data.len()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> SafeArray<'a> {
|
||||||
|
fn from_ndarray<'b: 'a, T: STDtype, D: ndarray::Dimension + 'a>(
|
||||||
|
array: ndarray::ArrayView<'b, T, D>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let shape = array.shape().to_vec();
|
||||||
|
let dtype = T::dtype();
|
||||||
|
if array.ndim() == 0 {
|
||||||
|
return Err(SafeTensorError::InvalidTensorData(
|
||||||
|
core::any::type_name::<T>(),
|
||||||
|
"Cannot insert a scalar tensor".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if !array.is_standard_layout() {
|
||||||
|
return Err(SafeTensorError::InvalidTensorData(
|
||||||
|
core::any::type_name::<T>(),
|
||||||
|
"ArrayView is not standard layout".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
let data =
|
||||||
|
bytemuck::cast_slice(array.to_slice().ok_or(SafeTensorError::InvalidTensorData(
|
||||||
|
core::any::type_name::<T>(),
|
||||||
|
"ArrayView is not contiguous".to_string(),
|
||||||
|
))?);
|
||||||
|
let safe_array = SafeArray {
|
||||||
|
data: Cow::Borrowed(data),
|
||||||
|
shape,
|
||||||
|
dtype,
|
||||||
|
};
|
||||||
|
Ok(safe_array)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_safe_array_from_ndarray() {
|
||||||
|
use ndarray::Array2;
|
||||||
|
|
||||||
|
let array = Array2::<f32>::zeros((3, 4));
|
||||||
|
let safe_array = SafeArray::from_ndarray(array.view()).unwrap();
|
||||||
|
assert_eq!(safe_array.shape, vec![3, 4]);
|
||||||
|
assert_eq!(safe_array.dtype, safetensors::tensor::Dtype::F32);
|
||||||
|
assert_eq!(safe_array.data.len(), 3 * 4 * 4); // 3x4x4 bytes for f32
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_serialize_safe_arrays() {
|
||||||
|
use ndarray::{Array2, Array3};
|
||||||
|
|
||||||
|
let mut safe_arrays = SafeArrays::new();
|
||||||
|
let array = Array2::<f32>::zeros((3, 4));
|
||||||
|
let array2 = Array3::<u16>::zeros((8, 1, 9));
|
||||||
|
safe_arrays
|
||||||
|
.insert_ndarray("test_tensor", array.view())
|
||||||
|
.unwrap();
|
||||||
|
safe_arrays
|
||||||
|
.insert_ndarray("test_tensor2", array2.view())
|
||||||
|
.unwrap();
|
||||||
|
safe_arrays.insert_metadata("author", "example");
|
||||||
|
|
||||||
|
let serialized = safe_arrays.serialize().unwrap();
|
||||||
|
assert!(!serialized.is_empty());
|
||||||
|
|
||||||
|
// Deserialize to check if it works
|
||||||
|
let deserialized = SafeArraysView::from_bytes(&serialized).unwrap();
|
||||||
|
assert_eq!(deserialized.len(), 2);
|
||||||
|
assert_eq!(
|
||||||
|
deserialized
|
||||||
|
.tensor::<f32, ndarray::Ix2>("test_tensor")
|
||||||
|
.unwrap()
|
||||||
|
.shape(),
|
||||||
|
&[3, 4]
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
deserialized
|
||||||
|
.tensor::<u16, ndarray::Ix3>("test_tensor2")
|
||||||
|
.unwrap()
|
||||||
|
.shape(),
|
||||||
|
&[8, 1, 9]
|
||||||
|
);
|
||||||
|
}
|
||||||
44
src/cli.rs
44
src/cli.rs
@@ -13,6 +13,12 @@ pub enum SubCommand {
|
|||||||
Detect(Detect),
|
Detect(Detect),
|
||||||
#[clap(name = "list")]
|
#[clap(name = "list")]
|
||||||
List(List),
|
List(List),
|
||||||
|
#[clap(name = "query")]
|
||||||
|
Query(Query),
|
||||||
|
#[clap(name = "similar")]
|
||||||
|
Similar(Similar),
|
||||||
|
#[clap(name = "stats")]
|
||||||
|
Stats(Stats),
|
||||||
#[clap(name = "completions")]
|
#[clap(name = "completions")]
|
||||||
Completions { shell: clap_complete::Shell },
|
Completions { shell: clap_complete::Shell },
|
||||||
}
|
}
|
||||||
@@ -58,12 +64,50 @@ pub struct Detect {
|
|||||||
pub nms_threshold: f32,
|
pub nms_threshold: f32,
|
||||||
#[clap(short, long, default_value_t = 8)]
|
#[clap(short, long, default_value_t = 8)]
|
||||||
pub batch_size: usize,
|
pub batch_size: usize,
|
||||||
|
#[clap(short = 'd', long)]
|
||||||
|
pub database: Option<PathBuf>,
|
||||||
|
#[clap(long, default_value = "facenet")]
|
||||||
|
pub model_name: String,
|
||||||
|
#[clap(long)]
|
||||||
|
pub save_to_db: bool,
|
||||||
pub image: PathBuf,
|
pub image: PathBuf,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, clap::Args)]
|
#[derive(Debug, clap::Args)]
|
||||||
pub struct List {}
|
pub struct List {}
|
||||||
|
|
||||||
|
#[derive(Debug, clap::Args)]
|
||||||
|
pub struct Query {
|
||||||
|
#[clap(short = 'd', long, default_value = "face_detections.db")]
|
||||||
|
pub database: PathBuf,
|
||||||
|
#[clap(short, long)]
|
||||||
|
pub image_id: Option<i64>,
|
||||||
|
#[clap(short, long)]
|
||||||
|
pub face_id: Option<i64>,
|
||||||
|
#[clap(long)]
|
||||||
|
pub show_embeddings: bool,
|
||||||
|
#[clap(long)]
|
||||||
|
pub show_landmarks: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, clap::Args)]
|
||||||
|
pub struct Similar {
|
||||||
|
#[clap(short = 'd', long, default_value = "face_detections.db")]
|
||||||
|
pub database: PathBuf,
|
||||||
|
#[clap(short, long)]
|
||||||
|
pub face_id: i64,
|
||||||
|
#[clap(short, long, default_value_t = 0.7)]
|
||||||
|
pub threshold: f32,
|
||||||
|
#[clap(short, long, default_value_t = 10)]
|
||||||
|
pub limit: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, clap::Args)]
|
||||||
|
pub struct Stats {
|
||||||
|
#[clap(short = 'd', long, default_value = "face_detections.db")]
|
||||||
|
pub database: PathBuf,
|
||||||
|
}
|
||||||
|
|
||||||
impl Cli {
|
impl Cli {
|
||||||
pub fn completions(shell: clap_complete::Shell) {
|
pub fn completions(shell: clap_complete::Shell) {
|
||||||
let mut command = <Cli as clap::CommandFactory>::command();
|
let mut command = <Cli as clap::CommandFactory>::command();
|
||||||
|
|||||||
597
src/database.rs
Normal file
597
src/database.rs
Normal file
@@ -0,0 +1,597 @@
|
|||||||
|
use crate::errors::{Error, Result};
|
||||||
|
use crate::facedet::{FaceDetectionOutput, FaceLandmarks};
|
||||||
|
use bounding_box::Aabb2;
|
||||||
|
use error_stack::ResultExt;
|
||||||
|
use ndarray_math::CosineSimilarity;
|
||||||
|
use rusqlite::{Connection, OptionalExtension, params};
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
/// Database connection and operations for face detection results
|
||||||
|
pub struct FaceDatabase {
|
||||||
|
conn: Connection,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Represents a stored image record
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ImageRecord {
|
||||||
|
pub id: i64,
|
||||||
|
pub file_path: String,
|
||||||
|
pub width: u32,
|
||||||
|
pub height: u32,
|
||||||
|
pub created_at: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Represents a stored face detection record
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct FaceRecord {
|
||||||
|
pub id: i64,
|
||||||
|
pub image_id: i64,
|
||||||
|
pub bbox_x1: f32,
|
||||||
|
pub bbox_y1: f32,
|
||||||
|
pub bbox_x2: f32,
|
||||||
|
pub bbox_y2: f32,
|
||||||
|
pub confidence: f32,
|
||||||
|
pub created_at: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Represents stored face landmarks
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct LandmarkRecord {
|
||||||
|
pub id: i64,
|
||||||
|
pub face_id: i64,
|
||||||
|
pub left_eye_x: f32,
|
||||||
|
pub left_eye_y: f32,
|
||||||
|
pub right_eye_x: f32,
|
||||||
|
pub right_eye_y: f32,
|
||||||
|
pub nose_x: f32,
|
||||||
|
pub nose_y: f32,
|
||||||
|
pub left_mouth_x: f32,
|
||||||
|
pub left_mouth_y: f32,
|
||||||
|
pub right_mouth_x: f32,
|
||||||
|
pub right_mouth_y: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Represents a stored face embedding
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct EmbeddingRecord {
|
||||||
|
pub id: i64,
|
||||||
|
pub face_id: i64,
|
||||||
|
pub embedding: ndarray::Array1<f32>,
|
||||||
|
pub model_name: String,
|
||||||
|
pub created_at: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FaceDatabase {
|
||||||
|
/// Create a new database connection and initialize tables
|
||||||
|
pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Self> {
|
||||||
|
let conn = Connection::open(db_path).change_context(Error)?;
|
||||||
|
add_sqlite_cosine_similarity(&conn).change_context(Error)?;
|
||||||
|
let db = Self { conn };
|
||||||
|
db.create_tables()?;
|
||||||
|
Ok(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create an in-memory database for testing
|
||||||
|
pub fn in_memory() -> Result<Self> {
|
||||||
|
let conn = Connection::open_in_memory().change_context(Error)?;
|
||||||
|
let db = Self { conn };
|
||||||
|
db.create_tables()?;
|
||||||
|
Ok(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create all necessary database tables
|
||||||
|
fn create_tables(&self) -> Result<()> {
|
||||||
|
// Images table
|
||||||
|
self.conn
|
||||||
|
.execute(
|
||||||
|
r#"
|
||||||
|
CREATE TABLE IF NOT EXISTS images (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
file_path TEXT NOT NULL UNIQUE,
|
||||||
|
width INTEGER NOT NULL,
|
||||||
|
height INTEGER NOT NULL,
|
||||||
|
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||||
|
)
|
||||||
|
"#,
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
// Faces table
|
||||||
|
self.conn
|
||||||
|
.execute(
|
||||||
|
r#"
|
||||||
|
CREATE TABLE IF NOT EXISTS faces (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
image_id INTEGER NOT NULL,
|
||||||
|
bbox_x1 REAL NOT NULL,
|
||||||
|
bbox_y1 REAL NOT NULL,
|
||||||
|
bbox_x2 REAL NOT NULL,
|
||||||
|
bbox_y2 REAL NOT NULL,
|
||||||
|
confidence REAL NOT NULL,
|
||||||
|
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE
|
||||||
|
)
|
||||||
|
"#,
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
// Landmarks table
|
||||||
|
self.conn
|
||||||
|
.execute(
|
||||||
|
r#"
|
||||||
|
CREATE TABLE IF NOT EXISTS landmarks (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
face_id INTEGER NOT NULL,
|
||||||
|
left_eye_x REAL NOT NULL,
|
||||||
|
left_eye_y REAL NOT NULL,
|
||||||
|
right_eye_x REAL NOT NULL,
|
||||||
|
right_eye_y REAL NOT NULL,
|
||||||
|
nose_x REAL NOT NULL,
|
||||||
|
nose_y REAL NOT NULL,
|
||||||
|
left_mouth_x REAL NOT NULL,
|
||||||
|
left_mouth_y REAL NOT NULL,
|
||||||
|
right_mouth_x REAL NOT NULL,
|
||||||
|
right_mouth_y REAL NOT NULL,
|
||||||
|
FOREIGN KEY (face_id) REFERENCES faces (id) ON DELETE CASCADE
|
||||||
|
)
|
||||||
|
"#,
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
// Embeddings table
|
||||||
|
self.conn
|
||||||
|
.execute(
|
||||||
|
r#"
|
||||||
|
CREATE TABLE IF NOT EXISTS embeddings (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
face_id INTEGER NOT NULL,
|
||||||
|
embedding BLOB NOT NULL,
|
||||||
|
model_name TEXT NOT NULL,
|
||||||
|
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
FOREIGN KEY (face_id) REFERENCES faces (id) ON DELETE CASCADE
|
||||||
|
)
|
||||||
|
"#,
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
// Create indexes for better performance
|
||||||
|
self.conn
|
||||||
|
.execute(
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_faces_image_id ON faces (image_id)",
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
self.conn
|
||||||
|
.execute(
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_landmarks_face_id ON landmarks (face_id)",
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
self.conn
|
||||||
|
.execute(
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_embeddings_face_id ON embeddings (face_id)",
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Store image metadata and return the image ID
|
||||||
|
pub fn store_image(&self, file_path: &str, width: u32, height: u32) -> Result<i64> {
|
||||||
|
let mut stmt = self
|
||||||
|
.conn
|
||||||
|
.prepare("INSERT OR REPLACE INTO images (file_path, width, height) VALUES (?1, ?2, ?3)")
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
stmt.execute(params![file_path, width, height])
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
Ok(self.conn.last_insert_rowid())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Store face detection results
|
||||||
|
pub fn store_face_detections(
|
||||||
|
&self,
|
||||||
|
image_id: i64,
|
||||||
|
detection_output: &FaceDetectionOutput,
|
||||||
|
) -> Result<Vec<i64>> {
|
||||||
|
let mut face_ids = Vec::new();
|
||||||
|
|
||||||
|
for (i, bbox) in detection_output.bbox.iter().enumerate() {
|
||||||
|
let confidence = detection_output.confidence.get(i).copied().unwrap_or(0.0);
|
||||||
|
|
||||||
|
let face_id = self.store_face(image_id, bbox, confidence)?;
|
||||||
|
face_ids.push(face_id);
|
||||||
|
|
||||||
|
// Store landmarks if available
|
||||||
|
if let Some(landmarks) = detection_output.landmark.get(i) {
|
||||||
|
self.store_landmarks(face_id, landmarks)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(face_ids)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Store a single face detection
|
||||||
|
pub fn store_face(&self, image_id: i64, bbox: &Aabb2<usize>, confidence: f32) -> Result<i64> {
|
||||||
|
let mut stmt = self
|
||||||
|
.conn
|
||||||
|
.prepare(
|
||||||
|
r#"
|
||||||
|
INSERT INTO faces (image_id, bbox_x1, bbox_y1, bbox_x2, bbox_y2, confidence)
|
||||||
|
VALUES (?1, ?2, ?3, ?4, ?5, ?6)
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
stmt.execute(params![
|
||||||
|
image_id,
|
||||||
|
bbox.x1() as f32,
|
||||||
|
bbox.y1() as f32,
|
||||||
|
bbox.x2() as f32,
|
||||||
|
bbox.y2() as f32,
|
||||||
|
confidence
|
||||||
|
])
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
Ok(self.conn.last_insert_rowid())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Store face landmarks
|
||||||
|
pub fn store_landmarks(&self, face_id: i64, landmarks: &FaceLandmarks) -> Result<i64> {
|
||||||
|
let mut stmt = self
|
||||||
|
.conn
|
||||||
|
.prepare(
|
||||||
|
r#"
|
||||||
|
INSERT INTO landmarks
|
||||||
|
(face_id, left_eye_x, left_eye_y, right_eye_x, right_eye_y,
|
||||||
|
nose_x, nose_y, left_mouth_x, left_mouth_y, right_mouth_x, right_mouth_y)
|
||||||
|
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
stmt.execute(params![
|
||||||
|
face_id,
|
||||||
|
landmarks.left_eye.x,
|
||||||
|
landmarks.left_eye.y,
|
||||||
|
landmarks.right_eye.x,
|
||||||
|
landmarks.right_eye.y,
|
||||||
|
landmarks.nose.x,
|
||||||
|
landmarks.nose.y,
|
||||||
|
landmarks.left_mouth.x,
|
||||||
|
landmarks.left_mouth.y,
|
||||||
|
landmarks.right_mouth.x,
|
||||||
|
landmarks.right_mouth.y,
|
||||||
|
])
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
Ok(self.conn.last_insert_rowid())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Store face embeddings
|
||||||
|
pub fn store_embeddings(
|
||||||
|
&self,
|
||||||
|
face_ids: &[i64],
|
||||||
|
embeddings: &[ndarray::Array2<f32>],
|
||||||
|
model_name: &str,
|
||||||
|
) -> Result<Vec<i64>> {
|
||||||
|
let mut embedding_ids = Vec::new();
|
||||||
|
|
||||||
|
for (face_idx, embedding_batch) in embeddings.iter().enumerate() {
|
||||||
|
for (batch_idx, embedding_row) in embedding_batch.rows().into_iter().enumerate() {
|
||||||
|
let global_idx = face_idx * embedding_batch.nrows() + batch_idx;
|
||||||
|
|
||||||
|
if global_idx >= face_ids.len() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let face_id = face_ids[global_idx];
|
||||||
|
let embedding_id =
|
||||||
|
self.store_single_embedding(face_id, embedding_row, model_name)?;
|
||||||
|
embedding_ids.push(embedding_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(embedding_ids)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Store a single embedding
|
||||||
|
pub fn store_single_embedding(
|
||||||
|
&self,
|
||||||
|
face_id: i64,
|
||||||
|
embedding: ndarray::ArrayView1<f32>,
|
||||||
|
model_name: &str,
|
||||||
|
) -> Result<i64> {
|
||||||
|
let embedding_bytes =
|
||||||
|
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding)])
|
||||||
|
.change_context(Error)?
|
||||||
|
.serialize()
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
let mut stmt = self
|
||||||
|
.conn
|
||||||
|
.prepare("INSERT INTO embeddings (face_id, embedding, model_name) VALUES (?1, ?2, ?3)")
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
stmt.execute(params![face_id, embedding_bytes, model_name])
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
Ok(self.conn.last_insert_rowid())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get image by ID
|
||||||
|
pub fn get_image(&self, image_id: i64) -> Result<Option<ImageRecord>> {
|
||||||
|
let mut stmt = self
|
||||||
|
.conn
|
||||||
|
.prepare("SELECT id, file_path, width, height, created_at FROM images WHERE id = ?1")
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
let result = stmt
|
||||||
|
.query_row(params![image_id], |row| {
|
||||||
|
Ok(ImageRecord {
|
||||||
|
id: row.get(0)?,
|
||||||
|
file_path: row.get(1)?,
|
||||||
|
width: row.get(2)?,
|
||||||
|
height: row.get(3)?,
|
||||||
|
created_at: row.get(4)?,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.optional()
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get all faces for an image
|
||||||
|
pub fn get_faces_for_image(&self, image_id: i64) -> Result<Vec<FaceRecord>> {
|
||||||
|
let mut stmt = self
|
||||||
|
.conn
|
||||||
|
.prepare(
|
||||||
|
r#"
|
||||||
|
SELECT id, image_id, bbox_x1, bbox_y1, bbox_x2, bbox_y2, confidence, created_at
|
||||||
|
FROM faces WHERE image_id = ?1
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
let face_iter = stmt
|
||||||
|
.query_map(params![image_id], |row| {
|
||||||
|
Ok(FaceRecord {
|
||||||
|
id: row.get(0)?,
|
||||||
|
image_id: row.get(1)?,
|
||||||
|
bbox_x1: row.get(2)?,
|
||||||
|
bbox_y1: row.get(3)?,
|
||||||
|
bbox_x2: row.get(4)?,
|
||||||
|
bbox_y2: row.get(5)?,
|
||||||
|
confidence: row.get(6)?,
|
||||||
|
created_at: row.get(7)?,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
let mut faces = Vec::new();
|
||||||
|
for face in face_iter {
|
||||||
|
faces.push(face.change_context(Error)?);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(faces)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get landmarks for a face
|
||||||
|
pub fn get_landmarks(&self, face_id: i64) -> Result<Option<LandmarkRecord>> {
|
||||||
|
let mut stmt = self
|
||||||
|
.conn
|
||||||
|
.prepare(
|
||||||
|
r#"
|
||||||
|
SELECT id, face_id, left_eye_x, left_eye_y, right_eye_x, right_eye_y,
|
||||||
|
nose_x, nose_y, left_mouth_x, left_mouth_y, right_mouth_x, right_mouth_y
|
||||||
|
FROM landmarks WHERE face_id = ?1
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
let result = stmt
|
||||||
|
.query_row(params![face_id], |row| {
|
||||||
|
Ok(LandmarkRecord {
|
||||||
|
id: row.get(0)?,
|
||||||
|
face_id: row.get(1)?,
|
||||||
|
left_eye_x: row.get(2)?,
|
||||||
|
left_eye_y: row.get(3)?,
|
||||||
|
right_eye_x: row.get(4)?,
|
||||||
|
right_eye_y: row.get(5)?,
|
||||||
|
nose_x: row.get(6)?,
|
||||||
|
nose_y: row.get(7)?,
|
||||||
|
left_mouth_x: row.get(8)?,
|
||||||
|
left_mouth_y: row.get(9)?,
|
||||||
|
right_mouth_x: row.get(10)?,
|
||||||
|
right_mouth_y: row.get(11)?,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.optional()
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get embeddings for a face
|
||||||
|
pub fn get_embeddings(&self, face_id: i64) -> Result<Vec<EmbeddingRecord>> {
|
||||||
|
let mut stmt = self
|
||||||
|
.conn
|
||||||
|
.prepare(
|
||||||
|
"SELECT id, face_id, embedding, model_name, created_at FROM embeddings WHERE face_id = ?1",
|
||||||
|
)
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
let embedding_iter = stmt
|
||||||
|
.query_map(params![face_id], |row| {
|
||||||
|
let embedding_bytes: Vec<u8> = row.get(2)?;
|
||||||
|
let embedding: ndarray::Array1<f32> = {
|
||||||
|
let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes)
|
||||||
|
.change_context(Error)
|
||||||
|
// .change_context(Error)?
|
||||||
|
.unwrap();
|
||||||
|
sf.tensor::<f32, ndarray::Ix1>("embedding")
|
||||||
|
// .change_context(Error)?
|
||||||
|
.unwrap()
|
||||||
|
.to_owned()
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(EmbeddingRecord {
|
||||||
|
id: row.get(0)?,
|
||||||
|
face_id: row.get(1)?,
|
||||||
|
embedding,
|
||||||
|
model_name: row.get(3)?,
|
||||||
|
created_at: row.get(4)?,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
let mut embeddings = Vec::new();
|
||||||
|
for embedding in embedding_iter {
|
||||||
|
embeddings.push(embedding.change_context(Error)?);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(embeddings)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get database statistics
|
||||||
|
pub fn get_stats(&self) -> Result<(usize, usize, usize, usize)> {
|
||||||
|
let images: usize = self
|
||||||
|
.conn
|
||||||
|
.query_row("SELECT COUNT(*) FROM images", [], |row| row.get(0))
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
let faces: usize = self
|
||||||
|
.conn
|
||||||
|
.query_row("SELECT COUNT(*) FROM faces", [], |row| row.get(0))
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
let landmarks: usize = self
|
||||||
|
.conn
|
||||||
|
.query_row("SELECT COUNT(*) FROM landmarks", [], |row| row.get(0))
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
let embeddings: usize = self
|
||||||
|
.conn
|
||||||
|
.query_row("SELECT COUNT(*) FROM embeddings", [], |row| row.get(0))
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
Ok((images, faces, landmarks, embeddings))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Find similar faces based on cosine similarity of embeddings
|
||||||
|
/// Return ids and similarity scores of similar faces
|
||||||
|
pub fn find_similar_faces(
|
||||||
|
&self,
|
||||||
|
embedding: &ndarray::Array1<f32>,
|
||||||
|
threshold: f32,
|
||||||
|
limit: usize,
|
||||||
|
) -> Result<Vec<(i64, f32)>> {
|
||||||
|
// Serialize the query embedding to bytes
|
||||||
|
let embedding_bytes =
|
||||||
|
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding.view())])
|
||||||
|
.change_context(Error)?
|
||||||
|
.serialize()
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
let mut stmt = self
|
||||||
|
.conn
|
||||||
|
.prepare(
|
||||||
|
r#" SELECT face_id, cosine_similarity(?1, embedding) as similarity
|
||||||
|
FROM embeddings
|
||||||
|
WHERE cosine_similarity(?1, embedding) >= ?2
|
||||||
|
ORDER BY similarity DESC
|
||||||
|
LIMIT ?3"#,
|
||||||
|
)
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
let result = stmt
|
||||||
|
.query_map(params![embedding_bytes, threshold, limit], |row| {
|
||||||
|
Ok((row.get::<_, i64>(0)?, row.get::<_, f32>(1)?))
|
||||||
|
})
|
||||||
|
.change_context(Error)?
|
||||||
|
.map(|r| r.change_context(Error))
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
|
||||||
|
// let mut results = Vec::new();
|
||||||
|
// for result in result_iter {
|
||||||
|
// results.push(result.change_context(Error)?);
|
||||||
|
// }
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add_sqlite_cosine_similarity(db: &Connection) -> Result<()> {
|
||||||
|
use rusqlite::functions::*;
|
||||||
|
db.create_scalar_function(
|
||||||
|
"cosine_similarity",
|
||||||
|
2,
|
||||||
|
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
|
||||||
|
move |ctx| {
|
||||||
|
if ctx.len() != 2 {
|
||||||
|
return Err(rusqlite::Error::UserFunctionError(
|
||||||
|
"cosine_similarity requires exactly 2 arguments".into(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
let array_1 = ctx.get_raw(0).as_blob()?;
|
||||||
|
let array_2 = ctx.get_raw(1).as_blob()?;
|
||||||
|
|
||||||
|
let array_1_st = ndarray_safetensors::SafeArraysView::from_bytes(array_1)
|
||||||
|
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||||
|
let array_2_st = ndarray_safetensors::SafeArraysView::from_bytes(array_2)
|
||||||
|
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||||
|
|
||||||
|
let array_view_1 = array_1_st
|
||||||
|
.tensor::<f32, ndarray::Ix1>("embedding")
|
||||||
|
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||||
|
let array_view_2 = array_2_st
|
||||||
|
.tensor::<f32, ndarray::Ix1>("embedding")
|
||||||
|
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||||
|
|
||||||
|
let similarity = array_view_1
|
||||||
|
.cosine_similarity(array_view_2)
|
||||||
|
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||||
|
|
||||||
|
Ok(similarity)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.change_context(Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_database_creation() -> Result<()> {
|
||||||
|
let db = FaceDatabase::in_memory()?;
|
||||||
|
let (images, faces, landmarks, embeddings) = db.get_stats()?;
|
||||||
|
assert_eq!(images, 0);
|
||||||
|
assert_eq!(faces, 0);
|
||||||
|
assert_eq!(landmarks, 0);
|
||||||
|
assert_eq!(embeddings, 0);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_store_and_retrieve_image() -> Result<()> {
|
||||||
|
let db = FaceDatabase::in_memory()?;
|
||||||
|
let image_id = db.store_image("/path/to/image.jpg", 800, 600)?;
|
||||||
|
|
||||||
|
let image = db.get_image(image_id)?.unwrap();
|
||||||
|
assert_eq!(image.file_path, "/path/to/image.jpg");
|
||||||
|
assert_eq!(image.width, 800);
|
||||||
|
assert_eq!(image.height, 600);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
pub mod database;
|
||||||
pub mod errors;
|
pub mod errors;
|
||||||
pub mod facedet;
|
pub mod facedet;
|
||||||
pub mod faceembed;
|
pub mod faceembed;
|
||||||
|
|||||||
197
src/main.rs
197
src/main.rs
@@ -1,9 +1,10 @@
|
|||||||
mod cli;
|
mod cli;
|
||||||
mod errors;
|
mod errors;
|
||||||
use bounding_box::roi::MultiRoi;
|
use bounding_box::roi::MultiRoi;
|
||||||
use detector::{facedet, facedet::FaceDetectionConfig, faceembed};
|
use detector::{database::FaceDatabase, facedet, facedet::FaceDetectionConfig, faceembed};
|
||||||
use errors::*;
|
use errors::*;
|
||||||
use fast_image_resize::ResizeOptions;
|
use fast_image_resize::ResizeOptions;
|
||||||
|
|
||||||
use ndarray::*;
|
use ndarray::*;
|
||||||
use ndarray_image::*;
|
use ndarray_image::*;
|
||||||
use ndarray_resize::NdFir;
|
use ndarray_resize::NdFir;
|
||||||
@@ -13,7 +14,7 @@ const RETINAFACE_MODEL_ONNX: &[u8] = include_bytes!("../models/retinaface.onnx")
|
|||||||
const FACENET_MODEL_ONNX: &[u8] = include_bytes!("../models/facenet.onnx");
|
const FACENET_MODEL_ONNX: &[u8] = include_bytes!("../models/facenet.onnx");
|
||||||
pub fn main() -> Result<()> {
|
pub fn main() -> Result<()> {
|
||||||
tracing_subscriber::fmt()
|
tracing_subscriber::fmt()
|
||||||
.with_env_filter("error")
|
.with_env_filter("info")
|
||||||
.with_thread_ids(true)
|
.with_thread_ids(true)
|
||||||
.with_thread_names(true)
|
.with_thread_names(true)
|
||||||
.with_target(false)
|
.with_target(false)
|
||||||
@@ -77,6 +78,15 @@ pub fn main() -> Result<()> {
|
|||||||
cli::SubCommand::List(list) => {
|
cli::SubCommand::List(list) => {
|
||||||
println!("List: {:?}", list);
|
println!("List: {:?}", list);
|
||||||
}
|
}
|
||||||
|
cli::SubCommand::Query(query) => {
|
||||||
|
run_query(query)?;
|
||||||
|
}
|
||||||
|
cli::SubCommand::Similar(similar) => {
|
||||||
|
run_similar(similar)?;
|
||||||
|
}
|
||||||
|
cli::SubCommand::Stats(stats) => {
|
||||||
|
run_stats(stats)?;
|
||||||
|
}
|
||||||
cli::SubCommand::Completions { shell } => {
|
cli::SubCommand::Completions { shell } => {
|
||||||
cli::Cli::completions(shell);
|
cli::Cli::completions(shell);
|
||||||
}
|
}
|
||||||
@@ -89,10 +99,22 @@ where
|
|||||||
D: facedet::FaceDetector,
|
D: facedet::FaceDetector,
|
||||||
E: faceembed::FaceEmbedder,
|
E: faceembed::FaceEmbedder,
|
||||||
{
|
{
|
||||||
|
// Initialize database if requested
|
||||||
|
let db = if detect.save_to_db {
|
||||||
|
let db_path = detect
|
||||||
|
.database
|
||||||
|
.as_ref()
|
||||||
|
.map(|p| p.as_path())
|
||||||
|
.unwrap_or_else(|| std::path::Path::new("face_detections.db"));
|
||||||
|
Some(FaceDatabase::new(db_path).change_context(Error)?)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
let image = image::open(&detect.image)
|
let image = image::open(&detect.image)
|
||||||
.change_context(Error)
|
.change_context(Error)
|
||||||
.attach_printable(detect.image.to_string_lossy().to_string())?;
|
.attach_printable(detect.image.to_string_lossy().to_string())?;
|
||||||
let image = image.into_rgb8();
|
let image = image.into_rgb8();
|
||||||
|
let (image_width, image_height) = image.dimensions();
|
||||||
let mut array = image
|
let mut array = image
|
||||||
.into_ndarray()
|
.into_ndarray()
|
||||||
.change_context(errors::Error)
|
.change_context(errors::Error)
|
||||||
@@ -106,6 +128,26 @@ where
|
|||||||
)
|
)
|
||||||
.change_context(errors::Error)
|
.change_context(errors::Error)
|
||||||
.attach_printable("Failed to detect faces")?;
|
.attach_printable("Failed to detect faces")?;
|
||||||
|
|
||||||
|
// Store image and face detections in database if requested
|
||||||
|
let (image_id, face_ids) = if let Some(ref database) = db {
|
||||||
|
let image_path = detect.image.to_string_lossy();
|
||||||
|
let img_id = database
|
||||||
|
.store_image(&image_path, image_width, image_height)
|
||||||
|
.change_context(Error)?;
|
||||||
|
let face_ids = database
|
||||||
|
.store_face_detections(img_id, &output)
|
||||||
|
.change_context(Error)?;
|
||||||
|
tracing::info!(
|
||||||
|
"Stored image {} with {} faces in database",
|
||||||
|
img_id,
|
||||||
|
face_ids.len()
|
||||||
|
);
|
||||||
|
(Some(img_id), Some(face_ids))
|
||||||
|
} else {
|
||||||
|
(None, None)
|
||||||
|
};
|
||||||
|
|
||||||
for bbox in &output.bbox {
|
for bbox in &output.bbox {
|
||||||
tracing::info!("Detected face: {:?}", bbox);
|
tracing::info!("Detected face: {:?}", bbox);
|
||||||
use bounding_box::draw::*;
|
use bounding_box::draw::*;
|
||||||
@@ -159,6 +201,25 @@ where
|
|||||||
})
|
})
|
||||||
.collect::<Result<Vec<Array2<f32>>>>()?;
|
.collect::<Result<Vec<Array2<f32>>>>()?;
|
||||||
|
|
||||||
|
// Store embeddings in database if requested
|
||||||
|
if let (Some(database), Some(face_ids)) = (&db, &face_ids) {
|
||||||
|
let embedding_ids = database
|
||||||
|
.store_embeddings(face_ids, &embeddings, &detect.model_name)
|
||||||
|
.change_context(Error)?;
|
||||||
|
tracing::info!("Stored {} embeddings in database", embedding_ids.len());
|
||||||
|
|
||||||
|
// Print database statistics
|
||||||
|
let (num_images, num_faces, num_landmarks, num_embeddings) =
|
||||||
|
database.get_stats().change_context(Error)?;
|
||||||
|
tracing::info!(
|
||||||
|
"Database stats - Images: {}, Faces: {}, Landmarks: {}, Embeddings: {}",
|
||||||
|
num_images,
|
||||||
|
num_faces,
|
||||||
|
num_landmarks,
|
||||||
|
num_embeddings
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
let v = array.view();
|
let v = array.view();
|
||||||
if let Some(output) = detect.output {
|
if let Some(output) = detect.output {
|
||||||
let image: image::RgbImage = v
|
let image: image::RgbImage = v
|
||||||
@@ -173,3 +234,135 @@ where
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn run_query(query: cli::Query) -> Result<()> {
|
||||||
|
let db = FaceDatabase::new(&query.database).change_context(Error)?;
|
||||||
|
|
||||||
|
if let Some(image_id) = query.image_id {
|
||||||
|
if let Some(image) = db.get_image(image_id).change_context(Error)? {
|
||||||
|
println!("Image: {}", image.file_path);
|
||||||
|
println!("Dimensions: {}x{}", image.width, image.height);
|
||||||
|
println!("Created: {}", image.created_at);
|
||||||
|
|
||||||
|
let faces = db.get_faces_for_image(image_id).change_context(Error)?;
|
||||||
|
println!("Faces found: {}", faces.len());
|
||||||
|
|
||||||
|
for face in faces {
|
||||||
|
println!(
|
||||||
|
" Face ID {}: bbox({:.1}, {:.1}, {:.1}, {:.1}), confidence: {:.3}",
|
||||||
|
face.id,
|
||||||
|
face.bbox_x1,
|
||||||
|
face.bbox_y1,
|
||||||
|
face.bbox_x2,
|
||||||
|
face.bbox_y2,
|
||||||
|
face.confidence
|
||||||
|
);
|
||||||
|
|
||||||
|
if query.show_landmarks {
|
||||||
|
if let Some(landmarks) = db.get_landmarks(face.id).change_context(Error)? {
|
||||||
|
println!(
|
||||||
|
" Landmarks: left_eye({:.1}, {:.1}), right_eye({:.1}, {:.1}), nose({:.1}, {:.1})",
|
||||||
|
landmarks.left_eye_x,
|
||||||
|
landmarks.left_eye_y,
|
||||||
|
landmarks.right_eye_x,
|
||||||
|
landmarks.right_eye_y,
|
||||||
|
landmarks.nose_x,
|
||||||
|
landmarks.nose_y
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if query.show_embeddings {
|
||||||
|
let embeddings = db.get_embeddings(face.id).change_context(Error)?;
|
||||||
|
for embedding in embeddings {
|
||||||
|
println!(
|
||||||
|
" Embedding ({}): {} dims, model: {}",
|
||||||
|
embedding.id,
|
||||||
|
embedding.embedding.len(),
|
||||||
|
embedding.model_name
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
println!("Image with ID {} not found", image_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(face_id) = query.face_id {
|
||||||
|
if let Some(landmarks) = db.get_landmarks(face_id).change_context(Error)? {
|
||||||
|
println!(
|
||||||
|
"Landmarks for face {}: left_eye({:.1}, {:.1}), right_eye({:.1}, {:.1}), nose({:.1}, {:.1})",
|
||||||
|
face_id,
|
||||||
|
landmarks.left_eye_x,
|
||||||
|
landmarks.left_eye_y,
|
||||||
|
landmarks.right_eye_x,
|
||||||
|
landmarks.right_eye_y,
|
||||||
|
landmarks.nose_x,
|
||||||
|
landmarks.nose_y
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
println!("No landmarks found for face {}", face_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
let embeddings = db.get_embeddings(face_id).change_context(Error)?;
|
||||||
|
println!(
|
||||||
|
"Embeddings for face {}: {} found",
|
||||||
|
face_id,
|
||||||
|
embeddings.len()
|
||||||
|
);
|
||||||
|
for embedding in embeddings {
|
||||||
|
println!(
|
||||||
|
" Embedding {}: {} dims, model: {}, created: {}",
|
||||||
|
embedding.id,
|
||||||
|
embedding.embedding.len(),
|
||||||
|
embedding.model_name,
|
||||||
|
embedding.created_at
|
||||||
|
);
|
||||||
|
// if query.show_embeddings {
|
||||||
|
// println!(" Values: {:?}", &embedding.embedding);
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_similar(similar: cli::Similar) -> Result<()> {
|
||||||
|
let db = FaceDatabase::new(&similar.database).change_context(Error)?;
|
||||||
|
|
||||||
|
let embeddings = db.get_embeddings(similar.face_id).change_context(Error)?;
|
||||||
|
if embeddings.is_empty() {
|
||||||
|
println!("No embeddings found for face {}", similar.face_id);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let query_embedding = &embeddings[0].embedding;
|
||||||
|
let similar_faces = db
|
||||||
|
.find_similar_faces(query_embedding, similar.threshold, similar.limit)
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"Found {} similar faces (threshold: {:.3}):",
|
||||||
|
similar_faces.len(),
|
||||||
|
similar.threshold
|
||||||
|
);
|
||||||
|
for (face_id, similarity) in similar_faces {
|
||||||
|
println!(" Face {}: similarity {:.3}", face_id, similarity);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_stats(stats: cli::Stats) -> Result<()> {
|
||||||
|
let db = FaceDatabase::new(&stats.database).change_context(Error)?;
|
||||||
|
let (images, faces, landmarks, embeddings) = db.get_stats().change_context(Error)?;
|
||||||
|
|
||||||
|
println!("Database Statistics:");
|
||||||
|
println!(" Images: {}", images);
|
||||||
|
println!(" Faces: {}", faces);
|
||||||
|
println!(" Landmarks: {}", landmarks);
|
||||||
|
println!(" Embeddings: {}", embeddings);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user