feat: save safetensors to the database
This commit is contained in:
@@ -37,7 +37,7 @@ nalgebra = { workspace = true }
|
||||
ndarray = "0.16.1"
|
||||
ndarray-image = { workspace = true }
|
||||
ndarray-resize = { workspace = true }
|
||||
rusqlite = { version = "0.37.0", features = ["modern-full"] }
|
||||
rusqlite = { version = "0.37.0", features = ["functions", "modern-full"] }
|
||||
tap = "1.0.1"
|
||||
thiserror = "2.0"
|
||||
tokio = "1.43.1"
|
||||
|
||||
@@ -114,16 +114,17 @@
|
||||
stdenv = p: p.clangStdenv;
|
||||
doCheck = false;
|
||||
LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib";
|
||||
ORT_LIB_LOCATION = "${patchedOnnxruntime}";
|
||||
ORT_ENV_SYSTEM_LIB_LOCATION = "${patchedOnnxruntime}/lib";
|
||||
ORT_ENV_PREFER_DYNAMIC_LINK = true;
|
||||
# ORT_LIB_LOCATION = "${patchedOnnxruntime}";
|
||||
# ORT_ENV_SYSTEM_LIB_LOCATION = "${patchedOnnxruntime}/lib";
|
||||
# ORT_ENV_PREFER_DYNAMIC_LINK = true;
|
||||
nativeBuildInputs = with pkgs; [
|
||||
cmake
|
||||
pkg-config
|
||||
];
|
||||
buildInputs = with pkgs;
|
||||
[
|
||||
# onnxruntime
|
||||
patchedOnnxruntime
|
||||
sqlite
|
||||
]
|
||||
++ (lib.optionals pkgs.stdenv.isDarwin [
|
||||
libiconv
|
||||
|
||||
4
justfile
4
justfile
@@ -9,5 +9,5 @@ open:
|
||||
bench:
|
||||
cargo build --release
|
||||
BINARY="" hyperfine --warmup 3 --export-markdown benchmark.md \
|
||||
"$CARGO_TARGET_DIR/release/detector detect -f coreml selfie.jpg" \
|
||||
"$CARGO_TARGET_DIR/release/detector detect -f coreml -b 16 selfie.jpg"
|
||||
"$CARGO_TARGET_DIR/release/detector detect -f cpu selfie.jpg" \
|
||||
"$CARGO_TARGET_DIR/release/detector detect -f cpu -b 1 selfie.jpg"
|
||||
|
||||
@@ -68,17 +68,17 @@ use safetensors::tensor::SafeTensors;
|
||||
/// let view = SafeArraysView::from_bytes(&bytes).unwrap();
|
||||
/// let tensor: ndarray::ArrayView2<f32> = view.tensor("data").unwrap();
|
||||
/// ```
|
||||
pub struct SafeArrayView<'a> {
|
||||
pub struct SafeArraysView<'a> {
|
||||
pub tensors: SafeTensors<'a>,
|
||||
}
|
||||
|
||||
impl<'a> SafeArrayView<'a> {
|
||||
impl<'a> SafeArraysView<'a> {
|
||||
fn new(tensors: SafeTensors<'a>) -> Self {
|
||||
Self { tensors }
|
||||
}
|
||||
|
||||
/// Create a SafeArraysView from serialized bytes
|
||||
pub fn from_bytes(bytes: &'a [u8]) -> Result<SafeArrayView<'a>> {
|
||||
pub fn from_bytes(bytes: &'a [u8]) -> Result<SafeArraysView<'a>> {
|
||||
let tensors = SafeTensors::deserialize(bytes)?;
|
||||
Ok(Self::new(tensors))
|
||||
}
|
||||
@@ -413,7 +413,7 @@ fn test_serialize_safe_arrays() {
|
||||
assert!(!serialized.is_empty());
|
||||
|
||||
// Deserialize to check if it works
|
||||
let deserialized = SafeArrayView::from_bytes(&serialized).unwrap();
|
||||
let deserialized = SafeArraysView::from_bytes(&serialized).unwrap();
|
||||
assert_eq!(deserialized.len(), 2);
|
||||
assert_eq!(
|
||||
deserialized
|
||||
|
||||
137
src/database.rs
137
src/database.rs
@@ -2,6 +2,7 @@ use crate::errors::{Error, Result};
|
||||
use crate::facedet::{FaceDetectionOutput, FaceLandmarks};
|
||||
use bounding_box::Aabb2;
|
||||
use error_stack::ResultExt;
|
||||
use ndarray_math::CosineSimilarity;
|
||||
use rusqlite::{Connection, OptionalExtension, params};
|
||||
use std::path::Path;
|
||||
|
||||
@@ -55,7 +56,7 @@ pub struct LandmarkRecord {
|
||||
pub struct EmbeddingRecord {
|
||||
pub id: i64,
|
||||
pub face_id: i64,
|
||||
pub embedding: Vec<f32>,
|
||||
pub embedding: ndarray::Array1<f32>,
|
||||
pub model_name: String,
|
||||
pub created_at: String,
|
||||
}
|
||||
@@ -64,6 +65,7 @@ impl FaceDatabase {
|
||||
/// Create a new database connection and initialize tables
|
||||
pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Self> {
|
||||
let conn = Connection::open(db_path).change_context(Error)?;
|
||||
add_sqlite_cosine_similarity(&conn).change_context(Error)?;
|
||||
let db = Self { conn };
|
||||
db.create_tables()?;
|
||||
Ok(db)
|
||||
@@ -431,10 +433,16 @@ impl FaceDatabase {
|
||||
let embedding_iter = stmt
|
||||
.query_map(params![face_id], |row| {
|
||||
let embedding_bytes: Vec<u8> = row.get(2)?;
|
||||
let embedding: Vec<f32> = embedding_bytes
|
||||
.chunks_exact(4)
|
||||
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
|
||||
.collect();
|
||||
let embedding: ndarray::Array1<f32> = {
|
||||
let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes)
|
||||
.change_context(Error)
|
||||
// .change_context(Error)?
|
||||
.unwrap();
|
||||
sf.tensor::<f32, ndarray::Ix1>("embedding")
|
||||
// .change_context(Error)?
|
||||
.unwrap()
|
||||
.to_owned()
|
||||
};
|
||||
|
||||
Ok(EmbeddingRecord {
|
||||
id: row.get(0)?,
|
||||
@@ -454,46 +462,6 @@ impl FaceDatabase {
|
||||
Ok(embeddings)
|
||||
}
|
||||
|
||||
/// Find similar faces by embedding (using cosine similarity)
|
||||
pub fn find_similar_faces(
|
||||
&self,
|
||||
query_embedding: &[f32],
|
||||
threshold: f32,
|
||||
limit: usize,
|
||||
) -> Result<Vec<(i64, f32)>> {
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare("SELECT face_id, embedding FROM embeddings")
|
||||
.change_context(Error)?;
|
||||
|
||||
let embedding_iter = stmt
|
||||
.query_map([], |row| {
|
||||
let face_id: i64 = row.get(0)?;
|
||||
let embedding_bytes: Vec<u8> = row.get(1)?;
|
||||
let embedding: Vec<f32> = embedding_bytes
|
||||
.chunks_exact(4)
|
||||
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
|
||||
.collect();
|
||||
Ok((face_id, embedding))
|
||||
})
|
||||
.change_context(Error)?;
|
||||
|
||||
let mut similarities = Vec::new();
|
||||
for result in embedding_iter {
|
||||
let (face_id, embedding) = result.change_context(Error)?;
|
||||
let similarity = cosine_similarity(query_embedding, &embedding);
|
||||
if similarity >= threshold {
|
||||
similarities.push((face_id, similarity));
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by similarity (descending) and limit results
|
||||
similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
similarities.truncate(limit);
|
||||
|
||||
Ok(similarities)
|
||||
}
|
||||
|
||||
/// Get database statistics
|
||||
pub fn get_stats(&self) -> Result<(usize, usize, usize, usize)> {
|
||||
let images: usize = self
|
||||
@@ -518,6 +486,85 @@ impl FaceDatabase {
|
||||
|
||||
Ok((images, faces, landmarks, embeddings))
|
||||
}
|
||||
|
||||
/// Find similar faces based on cosine similarity of embeddings
|
||||
/// Return ids and similarity scores of similar faces
|
||||
pub fn find_similar_faces(
|
||||
&self,
|
||||
embedding: &ndarray::Array1<f32>,
|
||||
threshold: f32,
|
||||
limit: usize,
|
||||
) -> Result<Vec<(i64, f32)>> {
|
||||
// Serialize the query embedding to bytes
|
||||
let embedding_bytes =
|
||||
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding.view())])
|
||||
.change_context(Error)?
|
||||
.serialize()
|
||||
.change_context(Error)?;
|
||||
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare(
|
||||
r#" SELECT face_id, cosine_similarity(?1, embedding) as similarity
|
||||
FROM embeddings
|
||||
WHERE cosine_similarity(?1, embedding) >= ?2
|
||||
ORDER BY similarity DESC
|
||||
LIMIT ?3"#,
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
let result = stmt
|
||||
.query_map(params![embedding_bytes, threshold, limit], |row| {
|
||||
Ok((row.get::<_, i64>(0)?, row.get::<_, f32>(1)?))
|
||||
})
|
||||
.change_context(Error)?
|
||||
.map(|r| r.change_context(Error))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
// let mut results = Vec::new();
|
||||
// for result in result_iter {
|
||||
// results.push(result.change_context(Error)?);
|
||||
// }
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
fn add_sqlite_cosine_similarity(db: &Connection) -> Result<()> {
|
||||
use rusqlite::functions::*;
|
||||
db.create_scalar_function(
|
||||
"cosine_similarity",
|
||||
2,
|
||||
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
|
||||
move |ctx| {
|
||||
if ctx.len() != 2 {
|
||||
return Err(rusqlite::Error::UserFunctionError(
|
||||
"cosine_similarity requires exactly 2 arguments".into(),
|
||||
));
|
||||
}
|
||||
let array_1 = ctx.get_raw(0).as_blob()?;
|
||||
let array_2 = ctx.get_raw(1).as_blob()?;
|
||||
|
||||
let array_1_st = ndarray_safetensors::SafeArraysView::from_bytes(array_1)
|
||||
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||
let array_2_st = ndarray_safetensors::SafeArraysView::from_bytes(array_2)
|
||||
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||
|
||||
let array_view_1 = array_1_st
|
||||
.tensor::<f32, ndarray::Ix1>("embedding")
|
||||
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||
let array_view_2 = array_2_st
|
||||
.tensor::<f32, ndarray::Ix1>("embedding")
|
||||
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||
|
||||
let similarity = array_view_1
|
||||
.cosine_similarity(array_view_2)
|
||||
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||
|
||||
Ok(similarity)
|
||||
},
|
||||
)
|
||||
.change_context(Error)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
11
src/main.rs
11
src/main.rs
@@ -14,7 +14,7 @@ const RETINAFACE_MODEL_ONNX: &[u8] = include_bytes!("../models/retinaface.onnx")
|
||||
const FACENET_MODEL_ONNX: &[u8] = include_bytes!("../models/facenet.onnx");
|
||||
pub fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter("error")
|
||||
.with_env_filter("info")
|
||||
.with_thread_ids(true)
|
||||
.with_thread_names(true)
|
||||
.with_target(false)
|
||||
@@ -319,12 +319,9 @@ fn run_query(query: cli::Query) -> Result<()> {
|
||||
embedding.model_name,
|
||||
embedding.created_at
|
||||
);
|
||||
if query.show_embeddings {
|
||||
println!(
|
||||
" Values: {:?}",
|
||||
&embedding.embedding[..embedding.embedding.len().min(10)]
|
||||
);
|
||||
}
|
||||
// if query.show_embeddings {
|
||||
// println!(" Values: {:?}", &embedding.embedding);
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user