feat: save safetensors to the database
This commit is contained in:
@@ -37,7 +37,7 @@ nalgebra = { workspace = true }
|
|||||||
ndarray = "0.16.1"
|
ndarray = "0.16.1"
|
||||||
ndarray-image = { workspace = true }
|
ndarray-image = { workspace = true }
|
||||||
ndarray-resize = { workspace = true }
|
ndarray-resize = { workspace = true }
|
||||||
rusqlite = { version = "0.37.0", features = ["modern-full"] }
|
rusqlite = { version = "0.37.0", features = ["functions", "modern-full"] }
|
||||||
tap = "1.0.1"
|
tap = "1.0.1"
|
||||||
thiserror = "2.0"
|
thiserror = "2.0"
|
||||||
tokio = "1.43.1"
|
tokio = "1.43.1"
|
||||||
|
|||||||
@@ -114,16 +114,17 @@
|
|||||||
stdenv = p: p.clangStdenv;
|
stdenv = p: p.clangStdenv;
|
||||||
doCheck = false;
|
doCheck = false;
|
||||||
LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib";
|
LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib";
|
||||||
ORT_LIB_LOCATION = "${patchedOnnxruntime}";
|
# ORT_LIB_LOCATION = "${patchedOnnxruntime}";
|
||||||
ORT_ENV_SYSTEM_LIB_LOCATION = "${patchedOnnxruntime}/lib";
|
# ORT_ENV_SYSTEM_LIB_LOCATION = "${patchedOnnxruntime}/lib";
|
||||||
ORT_ENV_PREFER_DYNAMIC_LINK = true;
|
# ORT_ENV_PREFER_DYNAMIC_LINK = true;
|
||||||
nativeBuildInputs = with pkgs; [
|
nativeBuildInputs = with pkgs; [
|
||||||
cmake
|
cmake
|
||||||
pkg-config
|
pkg-config
|
||||||
];
|
];
|
||||||
buildInputs = with pkgs;
|
buildInputs = with pkgs;
|
||||||
[
|
[
|
||||||
# onnxruntime
|
patchedOnnxruntime
|
||||||
|
sqlite
|
||||||
]
|
]
|
||||||
++ (lib.optionals pkgs.stdenv.isDarwin [
|
++ (lib.optionals pkgs.stdenv.isDarwin [
|
||||||
libiconv
|
libiconv
|
||||||
|
|||||||
4
justfile
4
justfile
@@ -9,5 +9,5 @@ open:
|
|||||||
bench:
|
bench:
|
||||||
cargo build --release
|
cargo build --release
|
||||||
BINARY="" hyperfine --warmup 3 --export-markdown benchmark.md \
|
BINARY="" hyperfine --warmup 3 --export-markdown benchmark.md \
|
||||||
"$CARGO_TARGET_DIR/release/detector detect -f coreml selfie.jpg" \
|
"$CARGO_TARGET_DIR/release/detector detect -f cpu selfie.jpg" \
|
||||||
"$CARGO_TARGET_DIR/release/detector detect -f coreml -b 16 selfie.jpg"
|
"$CARGO_TARGET_DIR/release/detector detect -f cpu -b 1 selfie.jpg"
|
||||||
|
|||||||
@@ -68,17 +68,17 @@ use safetensors::tensor::SafeTensors;
|
|||||||
/// let view = SafeArrayView::from_bytes(&bytes).unwrap();
|
/// let view = SafeArrayView::from_bytes(&bytes).unwrap();
|
||||||
/// let tensor: ndarray::ArrayView2<f32> = view.tensor("data").unwrap();
|
/// let tensor: ndarray::ArrayView2<f32> = view.tensor("data").unwrap();
|
||||||
/// ```
|
/// ```
|
||||||
pub struct SafeArrayView<'a> {
|
pub struct SafeArraysView<'a> {
|
||||||
pub tensors: SafeTensors<'a>,
|
pub tensors: SafeTensors<'a>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> SafeArrayView<'a> {
|
impl<'a> SafeArraysView<'a> {
|
||||||
fn new(tensors: SafeTensors<'a>) -> Self {
|
fn new(tensors: SafeTensors<'a>) -> Self {
|
||||||
Self { tensors }
|
Self { tensors }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a SafeArrayView from serialized bytes
|
/// Create a SafeArrayView from serialized bytes
|
||||||
pub fn from_bytes(bytes: &'a [u8]) -> Result<SafeArrayView<'a>> {
|
pub fn from_bytes(bytes: &'a [u8]) -> Result<SafeArraysView<'a>> {
|
||||||
let tensors = SafeTensors::deserialize(bytes)?;
|
let tensors = SafeTensors::deserialize(bytes)?;
|
||||||
Ok(Self::new(tensors))
|
Ok(Self::new(tensors))
|
||||||
}
|
}
|
||||||
@@ -413,7 +413,7 @@ fn test_serialize_safe_arrays() {
|
|||||||
assert!(!serialized.is_empty());
|
assert!(!serialized.is_empty());
|
||||||
|
|
||||||
// Deserialize to check if it works
|
// Deserialize to check if it works
|
||||||
let deserialized = SafeArrayView::from_bytes(&serialized).unwrap();
|
let deserialized = SafeArraysView::from_bytes(&serialized).unwrap();
|
||||||
assert_eq!(deserialized.len(), 2);
|
assert_eq!(deserialized.len(), 2);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
deserialized
|
deserialized
|
||||||
|
|||||||
137
src/database.rs
137
src/database.rs
@@ -2,6 +2,7 @@ use crate::errors::{Error, Result};
|
|||||||
use crate::facedet::{FaceDetectionOutput, FaceLandmarks};
|
use crate::facedet::{FaceDetectionOutput, FaceLandmarks};
|
||||||
use bounding_box::Aabb2;
|
use bounding_box::Aabb2;
|
||||||
use error_stack::ResultExt;
|
use error_stack::ResultExt;
|
||||||
|
use ndarray_math::CosineSimilarity;
|
||||||
use rusqlite::{Connection, OptionalExtension, params};
|
use rusqlite::{Connection, OptionalExtension, params};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
||||||
@@ -55,7 +56,7 @@ pub struct LandmarkRecord {
|
|||||||
pub struct EmbeddingRecord {
|
pub struct EmbeddingRecord {
|
||||||
pub id: i64,
|
pub id: i64,
|
||||||
pub face_id: i64,
|
pub face_id: i64,
|
||||||
pub embedding: Vec<f32>,
|
pub embedding: ndarray::Array1<f32>,
|
||||||
pub model_name: String,
|
pub model_name: String,
|
||||||
pub created_at: String,
|
pub created_at: String,
|
||||||
}
|
}
|
||||||
@@ -64,6 +65,7 @@ impl FaceDatabase {
|
|||||||
/// Create a new database connection and initialize tables
|
/// Create a new database connection and initialize tables
|
||||||
pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Self> {
|
pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Self> {
|
||||||
let conn = Connection::open(db_path).change_context(Error)?;
|
let conn = Connection::open(db_path).change_context(Error)?;
|
||||||
|
add_sqlite_cosine_similarity(&conn).change_context(Error)?;
|
||||||
let db = Self { conn };
|
let db = Self { conn };
|
||||||
db.create_tables()?;
|
db.create_tables()?;
|
||||||
Ok(db)
|
Ok(db)
|
||||||
@@ -431,10 +433,16 @@ impl FaceDatabase {
|
|||||||
let embedding_iter = stmt
|
let embedding_iter = stmt
|
||||||
.query_map(params![face_id], |row| {
|
.query_map(params![face_id], |row| {
|
||||||
let embedding_bytes: Vec<u8> = row.get(2)?;
|
let embedding_bytes: Vec<u8> = row.get(2)?;
|
||||||
let embedding: Vec<f32> = embedding_bytes
|
let embedding: ndarray::Array1<f32> = {
|
||||||
.chunks_exact(4)
|
let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes)
|
||||||
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
|
.change_context(Error)
|
||||||
.collect();
|
// .change_context(Error)?
|
||||||
|
.unwrap();
|
||||||
|
sf.tensor::<f32, ndarray::Ix1>("embedding")
|
||||||
|
// .change_context(Error)?
|
||||||
|
.unwrap()
|
||||||
|
.to_owned()
|
||||||
|
};
|
||||||
|
|
||||||
Ok(EmbeddingRecord {
|
Ok(EmbeddingRecord {
|
||||||
id: row.get(0)?,
|
id: row.get(0)?,
|
||||||
@@ -454,46 +462,6 @@ impl FaceDatabase {
|
|||||||
Ok(embeddings)
|
Ok(embeddings)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Find similar faces by embedding (using cosine similarity)
|
|
||||||
pub fn find_similar_faces(
|
|
||||||
&self,
|
|
||||||
query_embedding: &[f32],
|
|
||||||
threshold: f32,
|
|
||||||
limit: usize,
|
|
||||||
) -> Result<Vec<(i64, f32)>> {
|
|
||||||
let mut stmt = self
|
|
||||||
.conn
|
|
||||||
.prepare("SELECT face_id, embedding FROM embeddings")
|
|
||||||
.change_context(Error)?;
|
|
||||||
|
|
||||||
let embedding_iter = stmt
|
|
||||||
.query_map([], |row| {
|
|
||||||
let face_id: i64 = row.get(0)?;
|
|
||||||
let embedding_bytes: Vec<u8> = row.get(1)?;
|
|
||||||
let embedding: Vec<f32> = embedding_bytes
|
|
||||||
.chunks_exact(4)
|
|
||||||
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
|
|
||||||
.collect();
|
|
||||||
Ok((face_id, embedding))
|
|
||||||
})
|
|
||||||
.change_context(Error)?;
|
|
||||||
|
|
||||||
let mut similarities = Vec::new();
|
|
||||||
for result in embedding_iter {
|
|
||||||
let (face_id, embedding) = result.change_context(Error)?;
|
|
||||||
let similarity = cosine_similarity(query_embedding, &embedding);
|
|
||||||
if similarity >= threshold {
|
|
||||||
similarities.push((face_id, similarity));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sort by similarity (descending) and limit results
|
|
||||||
similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
|
||||||
similarities.truncate(limit);
|
|
||||||
|
|
||||||
Ok(similarities)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get database statistics
|
/// Get database statistics
|
||||||
pub fn get_stats(&self) -> Result<(usize, usize, usize, usize)> {
|
pub fn get_stats(&self) -> Result<(usize, usize, usize, usize)> {
|
||||||
let images: usize = self
|
let images: usize = self
|
||||||
@@ -518,6 +486,85 @@ impl FaceDatabase {
|
|||||||
|
|
||||||
Ok((images, faces, landmarks, embeddings))
|
Ok((images, faces, landmarks, embeddings))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Find similar faces based on cosine similarity of embeddings
|
||||||
|
/// Return ids and similarity scores of similar faces
|
||||||
|
pub fn find_similar_faces(
|
||||||
|
&self,
|
||||||
|
embedding: &ndarray::Array1<f32>,
|
||||||
|
threshold: f32,
|
||||||
|
limit: usize,
|
||||||
|
) -> Result<Vec<(i64, f32)>> {
|
||||||
|
// Serialize the query embedding to bytes
|
||||||
|
let embedding_bytes =
|
||||||
|
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding.view())])
|
||||||
|
.change_context(Error)?
|
||||||
|
.serialize()
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
let mut stmt = self
|
||||||
|
.conn
|
||||||
|
.prepare(
|
||||||
|
r#" SELECT face_id, cosine_similarity(?1, embedding) as similarity
|
||||||
|
FROM embeddings
|
||||||
|
WHERE cosine_similarity(?1, embedding) >= ?2
|
||||||
|
ORDER BY similarity DESC
|
||||||
|
LIMIT ?3"#,
|
||||||
|
)
|
||||||
|
.change_context(Error)?;
|
||||||
|
|
||||||
|
let result = stmt
|
||||||
|
.query_map(params![embedding_bytes, threshold, limit], |row| {
|
||||||
|
Ok((row.get::<_, i64>(0)?, row.get::<_, f32>(1)?))
|
||||||
|
})
|
||||||
|
.change_context(Error)?
|
||||||
|
.map(|r| r.change_context(Error))
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
|
||||||
|
// let mut results = Vec::new();
|
||||||
|
// for result in result_iter {
|
||||||
|
// results.push(result.change_context(Error)?);
|
||||||
|
// }
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add_sqlite_cosine_similarity(db: &Connection) -> Result<()> {
|
||||||
|
use rusqlite::functions::*;
|
||||||
|
db.create_scalar_function(
|
||||||
|
"cosine_similarity",
|
||||||
|
2,
|
||||||
|
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
|
||||||
|
move |ctx| {
|
||||||
|
if ctx.len() != 2 {
|
||||||
|
return Err(rusqlite::Error::UserFunctionError(
|
||||||
|
"cosine_similarity requires exactly 2 arguments".into(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
let array_1 = ctx.get_raw(0).as_blob()?;
|
||||||
|
let array_2 = ctx.get_raw(1).as_blob()?;
|
||||||
|
|
||||||
|
let array_1_st = ndarray_safetensors::SafeArraysView::from_bytes(array_1)
|
||||||
|
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||||
|
let array_2_st = ndarray_safetensors::SafeArraysView::from_bytes(array_2)
|
||||||
|
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||||
|
|
||||||
|
let array_view_1 = array_1_st
|
||||||
|
.tensor::<f32, ndarray::Ix1>("embedding")
|
||||||
|
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||||
|
let array_view_2 = array_2_st
|
||||||
|
.tensor::<f32, ndarray::Ix1>("embedding")
|
||||||
|
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||||
|
|
||||||
|
let similarity = array_view_1
|
||||||
|
.cosine_similarity(array_view_2)
|
||||||
|
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||||
|
|
||||||
|
Ok(similarity)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.change_context(Error)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|||||||
11
src/main.rs
11
src/main.rs
@@ -14,7 +14,7 @@ const RETINAFACE_MODEL_ONNX: &[u8] = include_bytes!("../models/retinaface.onnx")
|
|||||||
const FACENET_MODEL_ONNX: &[u8] = include_bytes!("../models/facenet.onnx");
|
const FACENET_MODEL_ONNX: &[u8] = include_bytes!("../models/facenet.onnx");
|
||||||
pub fn main() -> Result<()> {
|
pub fn main() -> Result<()> {
|
||||||
tracing_subscriber::fmt()
|
tracing_subscriber::fmt()
|
||||||
.with_env_filter("error")
|
.with_env_filter("info")
|
||||||
.with_thread_ids(true)
|
.with_thread_ids(true)
|
||||||
.with_thread_names(true)
|
.with_thread_names(true)
|
||||||
.with_target(false)
|
.with_target(false)
|
||||||
@@ -319,12 +319,9 @@ fn run_query(query: cli::Query) -> Result<()> {
|
|||||||
embedding.model_name,
|
embedding.model_name,
|
||||||
embedding.created_at
|
embedding.created_at
|
||||||
);
|
);
|
||||||
if query.show_embeddings {
|
// if query.show_embeddings {
|
||||||
println!(
|
// println!(" Values: {:?}", &embedding.embedding);
|
||||||
" Values: {:?}",
|
// }
|
||||||
&embedding.embedding[..embedding.embedding.len().min(10)]
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user