Compare commits


33 Commits

Author SHA1 Message Date
uttarayan21
65560825fa feat: add cargo-outdated and improve slider precision in app views
Some checks failed
build / checks-matrix (push) Successful in 19m24s
build / codecov (push) Failing after 19m27s
docs / docs (push) Failing after 28m47s
build / checks-build (push) Has been cancelled
2025-08-22 13:06:16 +05:30
uttarayan21
0a5dbaaadc refactor(gui): set fixed input dimensions for face detection 2025-08-21 18:52:58 +05:30
uttarayan21
3e14a16739 feat(gui): Added iced gui 2025-08-21 18:28:39 +05:30
uttarayan21
bfa389b497 feat(compare): add face comparison functionality with cosine similarity
Some checks failed
build / checks-matrix (push) Successful in 19m23s
build / codecov (push) Failing after 19m18s
docs / docs (push) Failing after 28m50s
build / checks-build (push) Has been cancelled
2025-08-21 17:34:07 +05:30
uttarayan21
f8122892e0 feat(ndarray-safetensors): add tensor_by_index method for SafeArraysView
Some checks failed
build / checks-matrix (push) Successful in 19m24s
build / codecov (push) Failing after 19m27s
docs / docs (push) Failing after 28m51s
build / checks-build (push) Has been cancelled
2025-08-20 16:05:18 +05:30
uttarayan21
97f64e7e10 feat: save safetensors to the database
Some checks failed
build / checks-matrix (push) Successful in 19m23s
build / codecov (push) Failing after 19m26s
docs / docs (push) Failing after 28m47s
build / checks-build (push) Has been cancelled
2025-08-20 12:17:18 +05:30
uttarayan21
37adb74adf feat: Save tensors to database as safetensor 2025-08-20 12:17:18 +05:30
uttarayan21
47218fa696 feat: Added ndarray-safetensors 2025-08-20 12:17:16 +05:30
uttarayan21
61466c9edd refactor(mnn): remove unused model loading methods from mnn.rs files
Some checks failed
build / checks-matrix (push) Successful in 19m22s
build / codecov (push) Failing after 19m26s
docs / docs (push) Failing after 28m50s
build / checks-build (push) Has been cancelled
2025-08-20 01:41:55 +05:30
uttarayan21
33798467ba fix(onnx): Use patched version of onnxruntime
Some checks failed
build / checks-matrix (push) Successful in 19m21s
build / codecov (push) Failing after 19m26s
docs / docs (push) Failing after 28m47s
build / checks-build (push) Has been cancelled
2025-08-19 15:28:38 +05:30
uttarayan21
3d56db687c feat: Added nix build support
Some checks failed
build / checks-matrix (push) Successful in 19m19s
build / codecov (push) Failing after 19m26s
docs / docs (push) Has been cancelled
build / checks-build (push) Has been cancelled
2025-08-19 14:31:02 +05:30
uttarayan21
cd12e97de3 feat: Added lfs
Some checks failed
build / checks-matrix (push) Successful in 19m24s
build / codecov (push) Failing after 19m26s
docs / docs (push) Failing after 28m47s
build / checks-build (push) Has been cancelled
2025-08-18 22:31:37 +05:30
uttarayan21
bd6520ce5a feat: Remove models
Some checks failed
build / checks-build (push) Has been cancelled
build / codecov (push) Has been cancelled
build / checks-matrix (push) Has been cancelled
docs / docs (push) Has been cancelled
2025-08-18 22:14:49 +05:30
uttarayan21
cd9c65ff6b feat: Remove git lfs 2025-08-18 22:14:33 +05:30
uttarayan21
cc26391610 chore: Reorder lfs entries
Some checks failed
build / checks-matrix (push) Has been cancelled
build / checks-build (push) Has been cancelled
build / codecov (push) Has been cancelled
docs / docs (push) Has been cancelled
2025-08-18 22:12:05 +05:30
uttarayan21
783320131a feat: Added stuff
Some checks failed
build / checks-build (push) Has been cancelled
build / checks-matrix (push) Has been cancelled
build / codecov (push) Has been cancelled
docs / docs (push) Has been cancelled
2025-08-18 22:10:29 +05:30
uttarayan21
7fc958b299 feat: Added more ort execution_provider
Some checks failed
build / checks-matrix (push) Failing after 19m0s
build / checks-build (push) Has been skipped
build / codecov (push) Failing after 19m3s
docs / docs (push) Failing after 28m31s
2025-08-18 16:31:16 +05:30
uttarayan21
3aa95a2ef5 feat: Added cli features for mnn and ort 2025-08-18 15:07:17 +05:30
uttarayan21
e7c9c38ed7 feat: implement the facenet implementation for ort 2025-08-18 13:20:55 +05:30
uttarayan21
5a1f4b9ef6 feat: Move facenet to same structure as facedet
Some checks failed
build / checks-matrix (push) Successful in 19m20s
build / codecov (push) Failing after 19m18s
docs / docs (push) Failing after 28m42s
build / checks-build (push) Has been cancelled
2025-08-18 12:59:35 +05:30
uttarayan21
087d841959 fix: Change the structure of builder
Some checks failed
build / checks-matrix (push) Successful in 19m21s
build / codecov (push) Failing after 19m18s
docs / docs (push) Has been cancelled
build / checks-build (push) Has been cancelled
2025-08-18 12:03:00 +05:30
uttarayan21
050e937408 feat: Changed the struct for retinaface
Some checks failed
build / checks-matrix (push) Has been cancelled
build / checks-build (push) Has been cancelled
build / codecov (push) Has been cancelled
docs / docs (push) Has been cancelled
2025-08-18 11:59:09 +05:30
uttarayan21
33afbfc2b8 feat: Added stuff
Some checks failed
build / checks-matrix (push) Successful in 19m25s
build / checks-build (push) Has been cancelled
build / codecov (push) Has been cancelled
docs / docs (push) Has been cancelled
2025-08-18 11:31:03 +05:30
uttarayan21
2d2309837f feat: Added stuff
Some checks failed
build / checks-matrix (push) Successful in 23m6s
build / codecov (push) Failing after 19m30s
docs / docs (push) Failing after 28m54s
build / checks-build (push) Has been cancelled
2025-08-13 18:08:03 +05:30
uttarayan21
f5740dc87f feat: Added .gitattributes and .gitignore 2025-08-08 15:19:59 +05:30
uttarayan21
3753e399b1 feat: Added models 2025-08-08 15:15:50 +05:30
uttarayan21
d52b69911f feat: Added facenet 2025-08-08 15:01:25 +05:30
uttarayan21
a3ea01b7b6 feat: Added facenet 2025-08-07 17:24:01 +05:30
uttarayan21
e60921b099 feat: Make nms return result if scores.len() != boxes.len() 2025-08-07 15:50:48 +05:30
uttarayan21
e91ae5b865 feat: Added a manual implementation of nms 2025-08-07 15:45:54 +05:30
uttarayan21
2c43f657aa backup: broken backup 2025-08-07 13:30:34 +05:30
uttarayan21
8d07b0846c feat: Working retinaface 2025-08-07 11:51:10 +05:30
uttarayan21
f7aae32caf broken: Remove the FaceDetectionConfig 2025-08-05 19:17:31 +05:30
52 changed files with 10467 additions and 922 deletions

4
.gitattributes vendored Normal file

@@ -0,0 +1,4 @@
models/facenet.onnx filter=lfs diff=lfs merge=lfs -text
models/retinaface.mnn filter=lfs diff=lfs merge=lfs -text
models/retinaface.onnx filter=lfs diff=lfs merge=lfs -text
models/facenet.mnn filter=lfs diff=lfs merge=lfs -text

3
.gitignore vendored

@@ -2,3 +2,6 @@
/target
.direnv
*.jpg
face_net.onnx
.DS_Store
*.cache

3
.gitmodules vendored Normal file

@@ -0,0 +1,3 @@
[submodule "rfcs"]
path = rfcs
url = git@github.com:aftershootco/rfcs.git

4551
Cargo.lock generated

File diff suppressed because it is too large.

View File

@@ -1,13 +1,10 @@
[workspace]
members = ["ndarray-image", "ndarray-resize", ".", "bounding-box"]
members = ["ndarray-image", "ndarray-resize", ".", "bounding-box", "ndarray-safetensors", "sqlite3-safetensor-cosine"]
[workspace.package]
version = "0.1.0"
edition = "2024"
[patch."https://github.com/uttarayan21/mnn-rs"]
mnn = { path = "/Users/fs0c131y/Projects/aftershoot/mnn-rs" }
[workspace.dependencies]
ndarray-image = { path = "ndarray-image" }
ndarray-resize = { path = "ndarray-resize" }
@@ -22,6 +19,7 @@ mnn-bridge = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0",
mnn-sync = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0", features = [
"tracing",
], branch = "restructure-tensor-type" }
nalgebra = { version = "0.34.0", default-features = false, features = ["std"] }
[package]
name = "detector"
@@ -35,12 +33,11 @@ clap_complete = "4.5"
error-stack = "0.5"
fast_image_resize = "5.2.0"
image = "0.25.6"
linfa = "0.7.1"
nalgebra = "0.33.2"
nalgebra = { workspace = true }
ndarray = "0.16.1"
ndarray-image = { workspace = true }
ndarray-resize = { workspace = true }
rusqlite = { version = "0.37.0", features = ["modern-full"] }
rusqlite = { version = "0.37.0", features = ["functions", "modern-full"] }
tap = "1.0.1"
thiserror = "2.0"
tokio = "1.43.1"
@@ -53,6 +50,28 @@ bounding-box = { version = "0.1.0", path = "bounding-box" }
color = "0.3.1"
itertools = "0.14.0"
ordered-float = "5.0.0"
ort = { version = "2.0.0-rc.10", default-features = false, features = [ "std", "tracing", "ndarray"]}
ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" }
ndarray-safetensors = { version = "0.1.0", path = "ndarray-safetensors" }
sqlite3-safetensor-cosine = { version = "0.1.0", path = "sqlite3-safetensor-cosine" }
# GUI dependencies
iced = { version = "0.13", features = ["tokio", "image"] }
rfd = "0.15"
futures = "0.3"
imageproc = "0.25"
[profile.release]
debug = true
[features]
ort-cuda = ["ort/cuda"]
ort-coreml = ["ort/coreml"]
ort-tensorrt = ["ort/tensorrt"]
ort-tvm = ["ort/tvm"]
ort-openvino = ["ort/openvino"]
ort-directml = ["ort/directml"]
mnn-metal = ["mnn/metal"]
mnn-coreml = ["mnn/coreml"]
default = ["mnn-metal","mnn-coreml"]

202
GUI_DEMO.md Normal file

@@ -0,0 +1,202 @@
# Face Detector GUI - Demo Documentation
## Overview
This document describes the modern GUI, with full image rendering, built for the face-detector project using iced.rs, a cross-platform GUI framework for Rust.
## What Was Built
### 🎯 Core Features Implemented
1. **Modern Tabbed Interface**
- Detection tab for single image face detection with visual results
- Comparison tab for face similarity comparison with side-by-side images
- Settings tab for model and parameter configuration
2. **Full Image Rendering System**
- Real-time image preview for selected input images
- Processed image display with bounding boxes drawn around detected faces
- Side-by-side comparison view for face matching
- Automatic image scaling and fitting within UI containers
- Support for displaying results from both MNN and ONNX backends
3. **File Management**
- Image file selection dialogs
- Output path selection for processed images
- Support for multiple image formats (jpg, jpeg, png, bmp, tiff, webp)
- Automatic image loading and display upon selection
4. **Real-time Parameter Control**
- Adjustable detection threshold (0.1-1.0)
- Adjustable NMS threshold (0.1-1.0)
- Model type selection (RetinaFace, YOLO)
- Execution backend selection (MNN CPU/Metal/CoreML, ONNX CPU)
5. **Progress Tracking**
- Status bar with current operation display
- Progress bar for long-running operations
- Processing time reporting
6. **Visual Results Display**
- Face count reporting with visual confirmation
- Processed images with red bounding boxes around detected faces
- Similarity scores with interpretation and color coding
- Error handling and display
- Before/after image comparison
## Architecture
### 🏗️ Project Structure
```
src/
├── gui/
│ ├── mod.rs # Module declarations
│ ├── app.rs # Main application logic
│ └── bridge.rs # Integration with face detection backend
├── bin/
│ └── gui.rs # GUI executable entry point
└── ... # Existing face detection modules
```
### 🔌 Integration Points
The GUI integrates with your existing face detection infrastructure:
- **Backend Support**: Both MNN and ONNX Runtime backends
- **Model Support**: RetinaFace and YOLO models
- **Hardware Acceleration**: Metal, CoreML, and CPU execution
- **Database Integration**: Ready for face database operations
## Technical Highlights
### ⚡ Performance Features
1. **Asynchronous Operations**: All face detection operations run asynchronously to keep the UI responsive
2. **Memory Efficient**: Proper resource management for image processing
3. **Hardware Accelerated**: Full support for Metal and CoreML on macOS
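
A minimal sketch of this async pattern, assuming iced 0.13's `Task::perform` API and hypothetical `App`, `Message`, and `detect_faces` names (not the actual GUI code):

```rust
use iced::Task;

// Hypothetical application state and messages, for illustration only.
struct App {
    input_path: String,
    face_count: Option<usize>,
}

#[derive(Debug, Clone)]
enum Message {
    Detect,
    Done(usize),
}

// Hypothetical async wrapper around the detection backend.
async fn detect_faces(_path: String) -> usize {
    0 // run the detector here and return the number of faces
}

fn update(app: &mut App, message: Message) -> Task<Message> {
    match message {
        // Spawn the detection future; the UI keeps rendering meanwhile.
        Message::Detect => Task::perform(detect_faces(app.input_path.clone()), Message::Done),
        Message::Done(count) => {
            app.face_count = Some(count);
            Task::none()
        }
    }
}
```

The detection future runs on the tokio executor and its result comes back as an ordinary message, which is what keeps the UI responsive during long operations.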
### 🎨 User Experience
1. **Intuitive Design**: Clean, modern interface with logical tab organization
2. **Real-time Feedback**: Immediate visual feedback for all operations
3. **Error Handling**: User-friendly error messages and recovery
4. **Accessibility**: Proper contrast and sizing for readability
## Usage Examples
### Running the GUI
```bash
# Build and run the GUI
cargo run --bin gui
# Or build the binary
cargo build --bin gui --release
./target/release/gui
```
### Face Detection Workflow
1. **Select Image**: Click "Select Image" to choose an input image
- Image immediately appears in the "Original Image" preview
2. **Adjust Parameters**: Use sliders to fine-tune detection thresholds
3. **Choose Backend**: Select MNN or ONNX execution backend
4. **Run Detection**: Click "Detect Faces" to process the image
5. **View Visual Results**:
- Original image displayed on the left
- Processed image with red bounding boxes on the right
- Face count, processing time, and status information below
### Face Comparison Workflow
1. **Select Images**: Choose two images for comparison
- Both images appear side-by-side in the comparison view
- "First Image" and "Second Image" clearly labeled
2. **Configure Settings**: Adjust detection and comparison parameters
3. **Run Comparison**: Click "Compare Faces" to analyze similarity
4. **View Visual Results**:
- Both original images displayed side-by-side for easy comparison
- Similarity scores with automatic interpretation and color coding:
- **> 0.8**: Very likely the same person (green text)
- **0.6-0.8**: Possibly the same person (yellow text)
- **0.4-0.6**: Unlikely to be the same person (orange text)
- **< 0.4**: Very unlikely to be the same person (red text)
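
A minimal sketch of the interpretation logic above (illustrative names, not the actual GUI code):

```rust
/// Map a cosine-similarity score to the label and text color shown in the UI.
fn interpret(score: f32) -> (&'static str, &'static str) {
    match score {
        s if s > 0.8 => ("Very likely the same person", "green"),
        s if s > 0.6 => ("Possibly the same person", "yellow"),
        s if s > 0.4 => ("Unlikely to be the same person", "orange"),
        _ => ("Very unlikely to be the same person", "red"),
    }
}
```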
## Current Status
### ✅ Successfully Implemented
- [x] Complete GUI framework integration
- [x] Tabbed interface with three main sections
- [x] File dialogs for image selection
- [x] **Full image rendering and display system**
- [x] **Real-time image preview for selected inputs**
- [x] **Processed image display with bounding boxes**
- [x] **Side-by-side image comparison view**
- [x] Parameter controls with real-time updates
- [x] Asynchronous operation handling
- [x] Progress tracking and status reporting
- [x] Integration with existing face detection backend
- [x] Support for both MNN and ONNX backends
- [x] Error handling and user feedback
- [x] Cross-platform compatibility (tested on macOS)
### 🔧 Known Issues
1. **Array Bounds Error**: There's a runtime error in the RetinaFace implementation that needs debugging:
```
thread 'tokio-runtime-worker' panicked at src/facedet/retinaface.rs:178:22:
ndarray: index 43008 is out of bounds for array of shape [43008]
```
This appears to be related to the original face detection logic, not the GUI code.
### 🚀 Future Enhancements
1. ~~**Image Display**: Add image preview and result visualization~~ ✅ **COMPLETED**
2. **Batch Processing**: Support for processing multiple images
3. **Database Integration**: GUI for face database operations
4. **Export Features**: Save results in various formats
5. **Configuration Persistence**: Remember user settings
6. **Drag & Drop**: Direct image dropping support
7. **Zoom and Pan**: Advanced image viewing capabilities
8. **Landmark Visualization**: Display facial landmarks on detected faces
## Technical Dependencies
### New Dependencies Added
```toml
# GUI dependencies
iced = { version = "0.13", features = ["tokio", "image"] }
rfd = "0.15" # File dialogs
futures = "0.3" # Async utilities
imageproc = "0.25" # Image processing utilities
```
### Integration Approach
The GUI was designed as a thin layer over your existing face detection engine:
- **Minimal Changes**: Only added new modules, no modifications to existing detection logic
- **Clean Separation**: GUI logic is completely separate from core detection algorithms
- **Reusable Components**: Bridge pattern allows easy extension to new backends
- **Maintainable Code**: Clear module boundaries and consistent error handling
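
A hedged sketch of the bridge idea with hypothetical names; the actual trait in `src/gui/bridge.rs` may differ:

```rust
/// A detected face as the GUI consumes it (illustrative shape).
pub struct FaceBox {
    pub x: f32,
    pub y: f32,
    pub w: f32,
    pub h: f32,
}

/// The GUI depends only on this interface, so backends stay swappable.
pub trait DetectionBridge: Send + Sync {
    fn detect(&self, image_path: &str, threshold: f32) -> Result<Vec<FaceBox>, String>;
}

pub struct MnnBridge; // would wrap the MNN detector
pub struct OrtBridge; // would wrap the ONNX Runtime detector
```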
## Compilation and Testing
The GUI compiles successfully with only minor warnings and has been tested on macOS with Apple Silicon. The interface is responsive and all UI components work as expected.
### Build Output
```
Finished `dev` profile [unoptimized + debuginfo] target(s) in 1m 05s
Running `/target/debug/gui`
```
The application launches properly, displays the GUI interface, and responds to user interactions. The only runtime issue is in the underlying face detection algorithm, which is separate from the GUI implementation.
## Conclusion
The GUI implementation successfully provides a modern, user-friendly interface for your face detection system. It maintains the full power and flexibility of your existing CLI tool while making it accessible to non-technical users through an intuitive graphical interface.
The architecture is extensible and maintainable, making it easy to add new features and functionality as your face detection system evolves.

BIN
KD4_7131.CR2 Normal file

Binary file not shown.

27
Makefile.toml Normal file

@@ -0,0 +1,27 @@
[tasks.convert_facenet]
command = "MNNConvert"
args = [
"-f",
"ONNX",
"--modelFile",
"models/facenet.onnx",
"--MNNModel",
"models/facenet.mnn",
"--fp16",
"--bizCode",
"MNN",
]
[tasks.convert_retinaface]
command = "MNNConvert"
args = [
"-f",
"ONNX",
"--modelFile",
"models/retinaface.onnx",
"--MNNModel",
"models/retinaface.mnn",
"--fp16",
"--bizCode",
"MNN",
]

228
README.md

@@ -1,3 +1,227 @@
# Face Detection
# Face Detection and Embedding
Rust programs to do face detection and face embedding
A high-performance Rust implementation for face detection and face embedding generation using neural networks.
## Overview
This project provides a complete face detection and recognition pipeline with the following capabilities:
- **Face Detection**: Detect faces in images using RetinaFace model
- **Face Embedding**: Generate face embeddings using FaceNet model
- **Multiple Backends**: Support for both MNN and ONNX runtime execution
- **Hardware Acceleration**: Metal, CoreML, and OpenCL support on compatible platforms
- **Modular Design**: Workspace architecture with reusable components
## Features
- 🔍 **Accurate Face Detection** - Uses RetinaFace model for robust face detection
- 🧠 **Face Embeddings** - Generate 512-dimensional face embeddings with FaceNet
- ⚡ **High Performance** - Optimized with hardware acceleration (Metal, CoreML)
- 🔧 **Flexible Configuration** - Adjustable detection thresholds and NMS parameters
- 📦 **Modular Architecture** - Reusable components for image processing and bounding boxes
- 🖼️ **Visual Output** - Draw bounding boxes on detected faces
## Architecture
The project is organized as a Rust workspace with the following components:
- **`detector`** - Main face detection and embedding application
- **`bounding-box`** - Geometric operations and drawing utilities for bounding boxes
- **`ndarray-image`** - Conversion utilities between ndarray and image formats
- **`ndarray-resize`** - Fast image resizing operations on ndarray data
## Models
The project includes pre-trained neural network models:
- **RetinaFace** - Face detection model (`.mnn` and `.onnx` formats)
- **FaceNet** - Face embedding model (`.mnn` and `.onnx` formats)
## Usage
### Basic Face Detection
```bash
# Detect faces using MNN backend (default)
cargo run --release detect path/to/image.jpg
# Detect faces using ONNX Runtime backend
cargo run --release detect --executor onnx path/to/image.jpg
# Save output with bounding boxes drawn
cargo run --release detect --output detected.jpg path/to/image.jpg
# Adjust detection sensitivity
cargo run --release detect --threshold 0.9 --nms-threshold 0.4 path/to/image.jpg
```
### Face Comparison
Compare faces between two images by computing and comparing their embeddings:
```bash
# Compare faces in two images
cargo run --release compare image1.jpg image2.jpg
# Compare with custom thresholds
cargo run --release compare --threshold 0.9 --nms-threshold 0.4 image1.jpg image2.jpg
# Use ONNX Runtime backend for comparison
cargo run --release compare -p cpu image1.jpg image2.jpg
# Use MNN with Metal acceleration
cargo run --release compare -f metal image1.jpg image2.jpg
```
The compare command will:
1. Detect all faces in both images
2. Generate embeddings for each detected face
3. Compute cosine similarity between all face pairs
4. Display similarity scores and the best match
5. Provide interpretation of the similarity scores:
- **> 0.8**: Very likely the same person
- **0.6-0.8**: Possibly the same person
- **0.4-0.6**: Unlikely to be the same person
- **< 0.4**: Very unlikely to be the same person
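
The cosine similarity in step 3 is the dot product of the two embeddings divided by the product of their norms; a minimal sketch using ndarray:

```rust
use ndarray::ArrayView1;

/// cos(a, b) = a · b / (|a| · |b|); 1.0 means the embeddings point the same way.
fn cosine_similarity(a: ArrayView1<'_, f32>, b: ArrayView1<'_, f32>) -> f32 {
    let dot = a.dot(&b);
    let norms = a.dot(&a).sqrt() * b.dot(&b).sqrt();
    if norms == 0.0 { 0.0 } else { dot / norms }
}
```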
### Backend Selection
The project supports two inference backends:
- **MNN Backend** (default): High-performance inference framework with Metal/CoreML support
- **ONNX Runtime Backend**: Cross-platform ML inference with broad hardware support
```bash
# Use MNN backend with Metal acceleration (macOS)
cargo run --release detect --executor mnn --forward-type metal path/to/image.jpg
# Use ONNX Runtime backend
cargo run --release detect --executor onnx path/to/image.jpg
```
### Command Line Options
```bash
# Face detection with custom parameters
cargo run --release detect [OPTIONS] <IMAGE>
Options:
-m, --model <MODEL> Custom model path
-M, --model-type <MODEL_TYPE> Model type [default: retina-face]
-o, --output <OUTPUT> Output image path
-e, --executor <EXECUTOR> Inference backend [mnn, onnx]
-f, --forward-type <FORWARD_TYPE> MNN execution backend [default: cpu]
-t, --threshold <THRESHOLD> Detection threshold [default: 0.8]
-n, --nms-threshold <NMS_THRESHOLD> NMS threshold [default: 0.3]
```
### Quick Start
```bash
# Build the project
cargo build --release
# Run face detection on sample image
just run
# or
cargo run --release detect ./1000066593.jpg
```
## Hardware Acceleration
### MNN Backend
The MNN backend supports various execution backends:
- **CPU** - Default, works on all platforms
- **Metal** - macOS GPU acceleration
- **CoreML** - macOS/iOS neural engine acceleration
- **OpenCL** - Cross-platform GPU acceleration
```bash
# Use Metal acceleration on macOS
cargo run --release detect --executor mnn --forward-type metal path/to/image.jpg
# Use CoreML on macOS/iOS
cargo run --release detect --executor mnn --forward-type coreml path/to/image.jpg
```
### ONNX Runtime Backend
The ONNX Runtime backend automatically selects the best available execution provider based on your system configuration.
## Development
### Prerequisites
- Rust 2024 edition
- MNN runtime (automatically linked)
- ONNX runtime (for ONNX backend)
### Building
```bash
# Standard build
cargo build
# Release build with optimizations
cargo build --release
# Run tests
cargo test
```
### Project Structure
```
├── src/
│ ├── facedet/ # Face detection modules
│ │ ├── mnn/ # MNN backend implementations
│ │ ├── ort/ # ONNX Runtime backend implementations
│ │ └── postprocess.rs # Shared postprocessing logic
│ ├── faceembed/ # Face embedding modules
│ │ ├── mnn/ # MNN backend implementations
│ │ └── ort/ # ONNX Runtime backend implementations
│ ├── cli.rs # Command line interface
│ └── main.rs # Application entry point
├── models/ # Neural network models (.mnn and .onnx)
├── bounding-box/ # Bounding box utilities
├── ndarray-image/ # Image conversion utilities
└── ndarray-resize/ # Image resizing utilities
```
### Backend Architecture
The codebase is organized to support multiple inference backends:
- **Common interfaces**: `FaceDetector` and `FaceEmbedder` traits provide unified APIs
- **Shared postprocessing**: Common logic for anchor generation, NMS, and coordinate decoding
- **Backend-specific implementations**: Separate modules for MNN and ONNX Runtime
- **Modular design**: Easy to add new backends by implementing the common traits
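
A hedged sketch of what such traits can look like; the actual signatures in `src/` may differ:

```rust
use ndarray::{Array1, ArrayView3};

/// A detected face: bounding box plus confidence (illustrative shape).
pub struct Detection {
    pub bbox: [f32; 4], // x1, y1, x2, y2
    pub score: f32,
}

pub trait FaceDetector {
    type Error;
    /// Detect faces in an HWC u8 image.
    fn detect(&mut self, image: ArrayView3<'_, u8>) -> Result<Vec<Detection>, Self::Error>;
}

pub trait FaceEmbedder {
    type Error;
    /// Produce a 512-dimensional embedding for an aligned face crop.
    fn embed(&mut self, face: ArrayView3<'_, u8>) -> Result<Array1<f32>, Self::Error>;
}
```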
## License
MIT License
## Dependencies
Key dependencies include:
- **MNN** - High-performance neural network inference framework (MNN backend)
- **ONNX Runtime** - Cross-platform ML inference (ORT backend)
- **ndarray** - N-dimensional array processing
- **image** - Image processing and I/O
- **clap** - Command line argument parsing
- **bounding-box** - Geometric operations for face detection
- **error-stack** - Structured error handling
### Backend Status
- ✅ **MNN Backend**: Fully implemented with hardware acceleration support
- 🚧 **ONNX Runtime Backend**: Framework implemented, inference logic to be completed
*Note: The ORT backend currently provides the framework but requires completion of the inference implementation.*
---
*Built with Rust for maximum performance and safety in computer vision applications.*

1
assets/headshots Symbolic link

@@ -0,0 +1 @@
/Users/fs0c131y/Pictures/test_cases/compressed/HeadshotJpeg

View File

@@ -6,12 +6,16 @@ edition = "2024"
[dependencies]
color = "0.3.1"
itertools = "0.14.0"
nalgebra = "0.33.2"
nalgebra = { workspace = true }
ndarray = { version = "0.16.1", optional = true }
num = "0.4.3"
ordered-float = "5.0.0"
simba = "0.9.0"
thiserror = "2.0.12"
tracing = { version = "0.1.41", optional = true, default-features = false }
[features]
ndarray = ["dep:ndarray"]
default = ["ndarray"]
tracing = ["dep:tracing"]
default = ["ndarray", "tracing"]

View File

@@ -4,11 +4,11 @@ pub use color::Rgba8;
use ndarray::{Array1, Array3, ArrayViewMut3};
pub trait Draw<T> {
fn draw(&mut self, item: T, color: color::Rgba8, thickness: usize);
fn draw(&mut self, item: &T, color: color::Rgba8, thickness: usize);
}
impl Draw<Aabb2<usize>> for Array3<u8> {
fn draw(&mut self, item: Aabb2<usize>, color: color::Rgba8, thickness: usize) {
fn draw(&mut self, item: &Aabb2<usize>, color: color::Rgba8, thickness: usize) {
item.draw(self, color, thickness)
}
}
@@ -65,8 +65,9 @@ impl Drawable<Array3<u8>> for Aabb2<usize> {
pixel.assign(&color);
})
})
.inspect_err(|e| {
dbg!(e);
.inspect_err(|_e| {
#[cfg(feature = "tracing")]
tracing::error!("{_e}")
})
.ok();
});

View File

@@ -2,9 +2,38 @@ pub mod draw;
pub mod nms;
pub mod roi;
use nalgebra::{Point, Point2, Point3, SVector, SimdPartialOrd, SimdValue};
pub trait Num: num::Num + Copy + core::fmt::Debug + 'static {}
impl<T: num::Num + Copy + core::fmt::Debug + 'static> Num for T {}
use nalgebra::{Point, Point2, SVector, Vector2};
pub trait Num:
num::Num
+ core::ops::AddAssign
+ core::ops::SubAssign
+ core::ops::MulAssign
+ core::ops::DivAssign
+ core::cmp::PartialOrd
+ core::cmp::PartialEq
+ nalgebra::SimdPartialOrd
+ nalgebra::SimdValue
+ Copy
+ core::fmt::Debug
+ 'static
{
}
impl<
T: num::Num
+ core::ops::AddAssign
+ core::ops::SubAssign
+ core::ops::MulAssign
+ core::ops::DivAssign
+ core::cmp::PartialOrd
+ core::cmp::PartialEq
+ nalgebra::SimdPartialOrd
+ nalgebra::SimdValue
+ Copy
+ core::fmt::Debug
+ 'static,
> Num for T
{
}
/// An axis aligned bounding box in `D` dimensions, defined by the minimum vertex and a size vector.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
@@ -20,16 +49,27 @@ pub type Aabb2<T> = AxisAlignedBoundingBox<T, 2>;
pub type Aabb3<T> = AxisAlignedBoundingBox<T, 3>;
impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
pub fn new(point: Point<T, D>, size: SVector<T, D>) -> Self {
// Panics if max < min
pub fn new(min_point: Point<T, D>, max_point: Point<T, D>) -> Self {
if max_point >= min_point {
Self::from_min_max_vertices(min_point, max_point)
} else {
panic!("max_point must be greater than or equal to min_point");
}
}
pub fn try_new(min_point: Point<T, D>, max_point: Point<T, D>) -> Option<Self> {
if max_point < min_point {
return None;
}
Some(Self::from_min_max_vertices(min_point, max_point))
}
pub fn new_point_size(point: Point<T, D>, size: SVector<T, D>) -> Self {
Self { point, size }
}
pub fn from_min_max_vertices(point1: Point<T, D>, point2: Point<T, D>) -> Self
where
T: core::ops::SubAssign,
{
let size = point2 - point1;
Self::new(point1, SVector::from(size))
pub fn from_min_max_vertices(min: Point<T, D>, max: Point<T, D>) -> Self {
let size = max - min;
Self::new_point_size(min, SVector::from(size))
}
/// Only considers the points closest and furthest from origin
@@ -151,7 +191,21 @@ impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
self.intersection(other)
}
pub fn union(&self, other: &Self) -> Self
pub fn component_clamp(&self, min: T, max: T) -> Self
where
T: PartialOrd,
{
let mut this = *self;
this.point.iter_mut().for_each(|x| {
*x = nalgebra::clamp(*x, min, max);
});
this.size.iter_mut().for_each(|x| {
*x = nalgebra::clamp(*x, min, max);
});
this
}
pub fn merge(&self, other: &Self) -> Self
where
T: core::ops::AddAssign,
T: core::ops::SubAssign,
@@ -159,13 +213,24 @@ impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
T: nalgebra::SimdValue,
T: nalgebra::SimdPartialOrd,
{
let self_min = self.min_vertex();
let self_max = self.max_vertex();
let other_min = other.min_vertex();
let other_max = other.max_vertex();
let min = self_min.inf(&other_min);
let max = self_max.sup(&other_max);
Self::from_min_max_vertices(min, max)
let min = self.min_vertex().inf(&other.min_vertex());
let max = self.min_vertex().sup(&other.max_vertex());
Self::new(min, max)
}
pub fn union(&self, other: &Self) -> T
where
T: core::ops::AddAssign,
T: core::ops::SubAssign,
T: core::ops::MulAssign,
T: PartialOrd,
T: nalgebra::SimdValue,
T: nalgebra::SimdPartialOrd,
{
self.measure() + other.measure()
- Self::intersection(self, other)
.map(|x| x.measure())
.unwrap_or(T::zero())
}
pub fn intersection(&self, other: &Self) -> Option<Self>
@@ -176,21 +241,9 @@ impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
T: nalgebra::SimdPartialOrd,
T: nalgebra::SimdValue,
{
let self_min = self.min_vertex();
let self_max = self.max_vertex();
let other_min = other.min_vertex();
let other_max = other.max_vertex();
if self_max < other_min || other_max < self_min {
return None; // No intersection
}
let min = self_min.sup(&other_min);
let max = self_max.inf(&other_max);
Some(Self::from_min_max_vertices(
Point::from(min),
Point::from(max),
))
let inter_min = self.min_vertex().sup(&other.min_vertex());
let inter_max = self.max_vertex().inf(&other.max_vertex());
Self::try_new(inter_min, inter_max)
}
pub fn denormalize(&self, factor: nalgebra::SVector<T, D>) -> Self
@@ -233,7 +286,7 @@ impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
self.size.product()
}
pub fn iou(&self, other: &Self) -> Option<T>
pub fn iou(&self, other: &Self) -> T
where
T: core::ops::AddAssign,
T: core::ops::SubAssign,
@@ -242,9 +295,19 @@ impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
T: nalgebra::SimdValue,
T: core::ops::MulAssign,
{
let intersection = self.intersection(other)?;
let union = self.union(other);
Some(intersection.measure() / union.measure())
let lhs_min = self.min_vertex();
let lhs_max = self.max_vertex();
let rhs_min = other.min_vertex();
let rhs_max = other.max_vertex();
let inter_min = lhs_min.sup(&rhs_min);
let inter_max = lhs_max.inf(&rhs_max);
if inter_max >= inter_min {
let intersection = Aabb::new(inter_min, inter_max).measure();
intersection / (self.measure() + other.measure() - intersection)
} else {
T::zero()
}
}
}
@@ -255,15 +318,15 @@ impl<T: Num> Aabb2<T> {
{
let point1 = Point2::new(x1, y1);
let point2 = Point2::new(x2, y2);
Self::from_min_max_vertices(point1, point2)
Self::new(point1, point2)
}
pub fn new_2d(point1: Point2<T>, point2: Point2<T>) -> Self
where
T: core::ops::SubAssign,
{
let size = point2.coords - point1.coords;
Self::new(point1, SVector::from(size))
pub fn from_xywh(x: T, y: T, w: T, h: T) -> Self {
let point = Point2::new(x, y);
let size = Vector2::new(w, h);
Self::new_point_size(point, size)
}
pub fn x1y1(&self) -> Point2<T> {
self.point
}
@@ -327,14 +390,6 @@ impl<T: Num> Aabb2<T> {
}
impl<T: Num> Aabb3<T> {
pub fn new_3d(point1: Point3<T>, point2: Point3<T>) -> Self
where
T: core::ops::SubAssign,
{
let size = point2.coords - point1.coords;
Self::new(point1, SVector::from(size))
}
pub fn volume(&self) -> T
where
T: core::ops::MulAssign,
@@ -343,13 +398,16 @@ impl<T: Num> Aabb3<T> {
}
}
#[cfg(test)]
mod bounding_box_tests {
use super::*;
use nalgebra::*;
#[test]
fn test_bbox_new() {
use nalgebra::{Point2, Vector2};
let point1 = Point2::new(1.0, 2.0);
let point2 = Point2::new(4.0, 6.0);
let bbox = AxisAlignedBoundingBox::new_2d(point1, point2);
let bbox = AxisAlignedBoundingBox::new(point1, point2);
assert_eq!(bbox.min_vertex(), point1);
assert_eq!(bbox.size(), Vector2::new(3.0, 4.0));
@@ -357,12 +415,24 @@ fn test_bbox_new() {
}
#[test]
fn test_bounding_box_center_2d() {
use nalgebra::{Point2, Vector2};
fn test_intersection_and_merge() {
let point1 = Point2::new(1, 5);
let point2 = Point2::new(3, 2);
let size1 = Vector2::new(3, 4);
let size2 = Vector2::new(1, 3);
let this = Aabb2::new_point_size(point1, size1);
let other = Aabb2::new_point_size(point2, size2);
let inter = this.intersection(&other);
let merged = this.merge(&other);
assert_ne!(inter, Some(merged))
}
#[test]
fn test_bounding_box_center_2d() {
let point = Point2::new(1.0, 2.0);
let size = Vector2::new(3.0, 4.0);
let bbox = AxisAlignedBoundingBox::new(point, size);
let bbox = AxisAlignedBoundingBox::new_point_size(point, size);
assert_eq!(bbox.min_vertex(), point);
assert_eq!(bbox.size(), size);
@@ -371,11 +441,9 @@ fn test_bounding_box_center_2d() {
#[test]
fn test_bounding_box_center_3d() {
use nalgebra::{Point3, Vector3};
let point = Point3::new(1.0, 2.0, 3.0);
let size = Vector3::new(4.0, 5.0, 6.0);
let bbox = AxisAlignedBoundingBox::new(point, size);
let bbox = AxisAlignedBoundingBox::new_point_size(point, size);
assert_eq!(bbox.min_vertex(), point);
assert_eq!(bbox.size(), size);
@@ -384,11 +452,9 @@ fn test_bounding_box_center_3d() {
#[test]
fn test_bounding_box_padding_2d() {
use nalgebra::{Point2, Vector2};
let point = Point2::new(1.0, 2.0);
let size = Vector2::new(3.0, 4.0);
let bbox = AxisAlignedBoundingBox::new(point, size);
let bbox = AxisAlignedBoundingBox::new_point_size(point, size);
let padded_bbox = bbox.padding(1.0);
assert_eq!(padded_bbox.min_vertex(), Point2::new(0.5, 1.5));
@@ -397,11 +463,9 @@ fn test_bounding_box_padding_2d() {
#[test]
fn test_bounding_box_scaling_2d() {
use nalgebra::{Point2, Vector2};
let point = Point2::new(1.0, 1.0);
let size = Vector2::new(3.0, 4.0);
let bbox = AxisAlignedBoundingBox::new(point, size);
let bbox = AxisAlignedBoundingBox::new_point_size(point, size);
let padded_bbox = bbox.scale(Vector2::new(2.0, 2.0));
assert_eq!(padded_bbox.min_vertex(), Point2::new(-2.0, -3.0));
@@ -410,11 +474,9 @@ fn test_bounding_box_scaling_2d() {
#[test]
fn test_bounding_box_contains_2d() {
use nalgebra::Point2;
let point1 = Point2::new(1.0, 2.0);
let point2 = Point2::new(4.0, 6.0);
let bbox = AxisAlignedBoundingBox::new_2d(point1, point2);
let bbox = AxisAlignedBoundingBox::new(point1, point2);
assert!(bbox.contains_point(&Point2::new(2.0, 3.0)));
assert!(!bbox.contains_point(&Point2::new(5.0, 7.0)));
@@ -422,32 +484,28 @@ fn test_bounding_box_contains_2d() {
#[test]
fn test_bounding_box_union_2d() {
use nalgebra::{Point2, Vector2};
let point1 = Point2::new(1.0, 2.0);
let point2 = Point2::new(4.0, 6.0);
let bbox1 = AxisAlignedBoundingBox::new_2d(point1, point2);
let bbox1 = AxisAlignedBoundingBox::new(point1, point2);
let point3 = Point2::new(3.0, 5.0);
let point4 = Point2::new(7.0, 8.0);
let bbox2 = AxisAlignedBoundingBox::new_2d(point3, point4);
let bbox2 = AxisAlignedBoundingBox::new(point3, point4);
let union_bbox = bbox1.union(&bbox2);
let union_bbox = bbox1.merge(&bbox2);
assert_eq!(union_bbox.min_vertex(), Point2::new(1.0, 2.0));
assert_eq!(union_bbox.size(), Vector2::new(6.0, 6.0));
}
#[test]
fn test_bounding_box_intersection_2d() {
use nalgebra::{Point2, Vector2};
let point1 = Point2::new(1.0, 2.0);
let point2 = Point2::new(4.0, 6.0);
let bbox1 = AxisAlignedBoundingBox::new_2d(point1, point2);
let bbox1 = AxisAlignedBoundingBox::new(point1, point2);
let point3 = Point2::new(3.0, 5.0);
let point4 = Point2::new(5.0, 7.0);
let bbox2 = AxisAlignedBoundingBox::new_2d(point3, point4);
let bbox2 = AxisAlignedBoundingBox::new(point3, point4);
let intersection_bbox = bbox1.intersection(&bbox2).unwrap();
assert_eq!(intersection_bbox.min_vertex(), Point2::new(3.0, 5.0));
@@ -456,11 +514,9 @@ fn test_bounding_box_intersection_2d() {
#[test]
fn test_bounding_box_contains_point() {
use nalgebra::Point2;
let point1 = Point2::new(2, 3);
let point2 = Point2::new(5, 4);
let bbox = AxisAlignedBoundingBox::new_2d(point1, point2);
let bbox = AxisAlignedBoundingBox::new(point1, point2);
use itertools::Itertools;
for (i, j) in (0..=10).cartesian_product(0..=10) {
if bbox.contains_point(&Point2::new(i, j)) {
@@ -496,3 +552,62 @@ fn test_bounding_box_clamp_box_2d() {
let expected = Aabb2::from_x1y1x2y2(5, 5, 7, 7);
assert_eq!(clamped, expected)
}
#[test]
fn test_iou_identical_boxes() {
let a = Aabb2::from_x1y1x2y2(1.0, 2.0, 4.0, 6.0);
let b = Aabb2::from_x1y1x2y2(1.0, 2.0, 4.0, 6.0);
assert_eq!(a.iou(&b), 1.0);
}
#[test]
fn test_iou_non_overlapping_boxes() {
let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 1.0, 1.0);
let b = Aabb2::from_x1y1x2y2(2.0, 2.0, 3.0, 3.0);
assert_eq!(a.iou(&b), 0.0);
}
#[test]
fn test_iou_partial_overlap() {
let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 2.0, 2.0);
let b = Aabb2::from_x1y1x2y2(1.0, 1.0, 3.0, 3.0);
// Intersection area = 1, Union area = 7
assert!((a.iou(&b) - 1.0 / 7.0).abs() < 1e-6);
}
#[test]
fn test_iou_one_inside_another() {
let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 4.0, 4.0);
let b = Aabb2::from_x1y1x2y2(1.0, 1.0, 3.0, 3.0);
// Intersection area = 4, Union area = 16
assert!((a.iou(&b) - 0.25).abs() < 1e-6);
}
#[test]
fn test_iou_edge_touching() {
let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 1.0, 1.0);
let b = Aabb2::from_x1y1x2y2(1.0, 0.0, 2.0, 1.0);
assert_eq!(a.iou(&b), 0.0);
}
#[test]
fn test_iou_corner_touching() {
let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 1.0, 1.0);
let b = Aabb2::from_x1y1x2y2(1.0, 1.0, 2.0, 2.0);
assert_eq!(a.iou(&b), 0.0);
}
#[test]
fn test_iou_zero_area_box() {
let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 0.0, 0.0);
let b = Aabb2::from_x1y1x2y2(0.0, 0.0, 1.0, 1.0);
assert_eq!(a.iou(&b), 0.0);
}
#[test]
fn test_specific_values() {
let box1 = Aabb2::from_xywh(0.69482, 0.6716774, 0.07493961, 0.14968264);
let box2 = Aabb2::from_xywh(0.41546485, 0.70290875, 0.06197411, 0.08818436);
assert!(box1.iou(&box2) >= 0.0);
}
}

View File

@@ -1,4 +1,11 @@
use std::collections::HashSet;
use std::collections::{HashSet, VecDeque};
use itertools::Itertools;
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum NmsError {
#[error("Boxes and scores length mismatch (boxes: {boxes}, scores: {scores})")]
BoxesAndScoresLengthMismatch { boxes: usize, scores: usize },
}
use crate::*;
/// Apply Non-Maximum Suppression to a set of bounding boxes.
@@ -18,10 +25,11 @@ pub fn nms<T>(
scores: &[T],
score_threshold: T,
nms_threshold: T,
) -> HashSet<usize>
) -> Result<HashSet<usize>, NmsError>
where
T: Num
+ num::Float
+ ordered_float::FloatCore
+ core::ops::Neg<Output = T>
+ core::iter::Product<T>
+ core::ops::AddAssign
+ core::ops::SubAssign
@@ -29,56 +37,37 @@ where
+ nalgebra::SimdValue
+ nalgebra::SimdPartialOrd,
{
use itertools::Itertools;
// Create vector of (index, box, score) tuples for boxes with scores above threshold
let mut indexed_boxes: Vec<(usize, &Aabb2<T>, &T)> = boxes
if boxes.len() != scores.len() {
return Err(NmsError::BoxesAndScoresLengthMismatch {
boxes: boxes.len(),
scores: scores.len(),
});
}
let mut combined: VecDeque<(usize, Aabb2<T>, T, bool)> = boxes
.iter()
.enumerate()
.zip(scores.iter())
.zip(scores)
.filter_map(|((idx, bbox), score)| {
if *score >= score_threshold {
Some((idx, bbox, score))
} else {
None
}
(*score > score_threshold).then_some((idx, *bbox, *score, true))
})
.sorted_by_cached_key(|(_, _, score, _)| -ordered_float::OrderedFloat(*score))
.collect();
// Sort by score in descending order
indexed_boxes.sort_by(|(_, _, score_a), (_, _, score_b)| {
score_b
.partial_cmp(score_a)
.unwrap_or(std::cmp::Ordering::Equal)
});
let mut keep_indices = HashSet::new();
let mut suppressed = HashSet::new();
for (i, (idx_i, bbox_i, _)) in indexed_boxes.iter().enumerate() {
// Skip if this box is already suppressed
if suppressed.contains(idx_i) {
for i in 0..combined.len() {
let first = combined[i];
if !first.3 {
continue;
}
// Keep this box
keep_indices.insert(*idx_i);
// Compare with remaining boxes
for (idx_j, bbox_j, _) in indexed_boxes.iter().skip(i + 1) {
// Skip if this box is already suppressed
if suppressed.contains(idx_j) {
continue;
}
// Calculate IoU and suppress if above threshold
if let Some(iou) = bbox_i.iou(bbox_j) {
if iou >= nms_threshold {
suppressed.insert(*idx_j);
}
let bbox = first.1;
for item in combined.iter_mut().skip(i + 1) {
if bbox.iou(&item.1) > nms_threshold {
item.3 = false
}
}
}
keep_indices
Ok(combined
.into_iter()
.filter_map(|(idx, _, _, keep)| keep.then_some(idx))
.collect())
}
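
With the fallible signature above, callers now handle the length mismatch explicitly. A minimal usage sketch, assuming the `bounding_box` crate path and the CLI's default thresholds (0.8 score, 0.3 IoU):

```rust
use bounding_box::{nms::nms, Aabb2};

fn keep_indices(boxes: &[Aabb2<f32>], scores: &[f32]) -> Vec<usize> {
    match nms(boxes, scores, 0.8, 0.3) {
        Ok(keep) => keep.into_iter().collect(),
        Err(e) => {
            // Only possible error: boxes.len() != scores.len().
            eprintln!("nms failed: {e}");
            Vec::new()
        }
    }
}
```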

View File

@@ -5,10 +5,17 @@ pub trait Roi<'a, Output> {
type Error;
fn roi(&'a self, aabb: Aabb2<usize>) -> Result<Output, Self::Error>;
}
pub trait RoiMut<'a, Output> {
type Error;
fn roi_mut(&'a mut self, aabb: Aabb2<usize>) -> Result<Output, Self::Error>;
}
pub trait MultiRoi<'a, Output> {
type Error;
fn multi_roi(&'a self, aabbs: &[Aabb2<usize>]) -> Result<Output, Self::Error>;
}
#[derive(thiserror::Error, Debug, Copy, Clone)]
pub enum RoiError {
#[error("Region of interest is out of bounds")]
@@ -36,7 +43,7 @@ impl<'a, T: Num> RoiMut<'a, ArrayViewMut3<'a, T>> for Array3<T> {
let x2 = aabb.x2();
let y1 = aabb.y1();
let y2 = aabb.y2();
if x1 >= x2 || y1 >= y2 || x2 > self.shape()[1] || y2 > self.shape()[0] {
if x1 > x2 || y1 > y2 || x2 > self.shape()[1] || y2 > self.shape()[0] {
return Err(RoiError::RoiOutOfBounds);
}
Ok(self.slice_mut(ndarray::s![y1..y2, x1..x2, ..]))
@@ -95,3 +102,47 @@ pub fn reborrow_test() {
};
dbg!(y);
}
impl<'a> MultiRoi<'a, Vec<ArrayView3<'a, u8>>> for Array3<u8> {
type Error = RoiError;
fn multi_roi(&'a self, aabbs: &[Aabb2<usize>]) -> Result<Vec<ArrayView3<'a, u8>>, Self::Error> {
let (height, width, _channels) = self.dim();
let outer_aabb = Aabb2::from_x1y1x2y2(0, 0, width, height);
aabbs
.iter()
.map(|aabb| {
let slice_arg =
bbox_to_slice_arg(aabb.clamp(&outer_aabb).ok_or(RoiError::RoiOutOfBounds)?);
Ok(self.slice(slice_arg))
})
.collect::<Result<Vec<_>, RoiError>>()
}
}
impl<'a, 'b> MultiRoi<'a, Vec<ArrayView3<'b, u8>>> for ArrayView3<'b, u8> {
type Error = RoiError;
fn multi_roi(&'a self, aabbs: &[Aabb2<usize>]) -> Result<Vec<ArrayView3<'b, u8>>, Self::Error> {
let (height, width, _channels) = self.dim();
let outer_aabb = Aabb2::from_x1y1x2y2(0, 0, width, height);
aabbs
.iter()
.map(|aabb| {
let slice_arg =
bbox_to_slice_arg(aabb.clamp(&outer_aabb).ok_or(RoiError::RoiOutOfBounds)?);
Ok(self.slice_move(slice_arg))
})
.collect::<Result<Vec<_>, RoiError>>()
}
}
fn bbox_to_slice_arg(
aabb: Aabb2<usize>,
) -> ndarray::SliceInfo<[ndarray::SliceInfoElem; 3], ndarray::Ix3, ndarray::Ix3> {
// Convert the bounding box into an ndarray slice argument: rows y1..y2, columns x1..x2, all channels.
let x1 = aabb.x1();
let x2 = aabb.x2();
let y1 = aabb.y1();
let y2 = aabb.y2();
ndarray::s![y1..y2, x1..x2, ..]
}

62
cr2.xmp Normal file

@@ -0,0 +1,62 @@
<?xpacket begin='' id='W5M0MpCehiHzreSzNTczkc9d'?><x:xmpmeta xmlns:x="adobe:ns:meta/"><rdf:RDF xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"><rdf:Description rdf:about="" xmlns:xmp="http://ns.adobe.com/xap/1.0/"><xmp:Rating>0</xmp:Rating></rdf:Description></rdf:RDF></x:xmpmeta>
<?xpacket end='w'?>

9
embedding.sql Normal file

@@ -0,0 +1,9 @@
.load /Users/fs0c131y/.cache/cargo/target/release/libsqlite3_safetensor_cosine.dylib
SELECT
cosine_similarity(e1.embedding, e2.embedding) AS similarity
FROM
embeddings AS e1
CROSS JOIN embeddings AS e2
WHERE
e1.id = e2.id;
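
The same function can also be registered in-process via rusqlite's `functions` feature (enabled in Cargo.toml above). A minimal sketch, assuming embeddings stored as packed little-endian f32 blobs; `decode_embedding` is a hypothetical stand-in for the crate's safetensors decoding:

```rust
use rusqlite::functions::FunctionFlags;
use rusqlite::{Connection, Result};

/// Hypothetical helper: interpret a blob as packed little-endian f32s.
fn decode_embedding(blob: &[u8]) -> Vec<f32> {
    blob.chunks_exact(4)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}

fn register_cosine(conn: &Connection) -> Result<()> {
    conn.create_scalar_function(
        "cosine_similarity",
        2,
        FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
        |ctx| {
            let a = decode_embedding(&ctx.get::<Vec<u8>>(0)?);
            let b = decode_embedding(&ctx.get::<Vec<u8>>(1)?);
            let dot: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
            let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
            Ok((dot / (norm(&a) * norm(&b))) as f64)
        },
    )
}
```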

32
flake.lock generated

@@ -3,11 +3,11 @@
"advisory-db": {
"flake": false,
"locked": {
"lastModified": 1750151065,
"narHash": "sha256-il+CAqChFIB82xP6bO43dWlUVs+NlG7a4g8liIP5HcI=",
"lastModified": 1755283329,
"narHash": "sha256-33bd+PHbon+cgEiWE/zkr7dpEF5E0DiHOzyoUQbkYBc=",
"owner": "rustsec",
"repo": "advisory-db",
"rev": "7573f55ba337263f61167dbb0ea926cdc7c8eb5d",
"rev": "61aac2116c8cb7cc80ff8ca283eec7687d384038",
"type": "github"
},
"original": {
@@ -18,11 +18,11 @@
},
"crane": {
"locked": {
"lastModified": 1750266157,
"narHash": "sha256-tL42YoNg9y30u7zAqtoGDNdTyXTi8EALDeCB13FtbQA=",
"lastModified": 1754269165,
"narHash": "sha256-0tcS8FHd4QjbCVoxN9jI+PjHgA4vc/IjkUSp+N3zy0U=",
"owner": "ipetkov",
"repo": "crane",
"rev": "e37c943371b73ed87faf33f7583860f81f1d5a48",
"rev": "444e81206df3f7d92780680e45858e31d2f07a08",
"type": "github"
},
"original": {
@@ -109,16 +109,16 @@
"mnn-src": {
"flake": false,
"locked": {
"lastModified": 1749173738,
"narHash": "sha256-pNljvQ4xMZ4VmuxQyXt+boNBZD0+UZNpNLrWrj8Rtfw=",
"lastModified": 1753256753,
"narHash": "sha256-aTpwVZBkpQiwOVVXDfQIVEx9CswNiPbvNftw8KsoW+Q=",
"owner": "alibaba",
"repo": "MNN",
"rev": "ebdada82634300956e08bd4056ecfeb1e4f23b32",
"rev": "a739ea5870a4a45680f0e36ba9662ca39f2f4eec",
"type": "github"
},
"original": {
"owner": "alibaba",
"ref": "3.2.0",
"ref": "3.2.2",
"repo": "MNN",
"type": "github"
}
@@ -145,11 +145,11 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1750506804,
"narHash": "sha256-VLFNc4egNjovYVxDGyBYTrvVCgDYgENp5bVi9fPTDYc=",
"lastModified": 1755186698,
"narHash": "sha256-wNO3+Ks2jZJ4nTHMuks+cxAiVBGNuEBXsT29Bz6HASo=",
"owner": "nixos",
"repo": "nixpkgs",
"rev": "4206c4cb56751df534751b058295ea61357bbbaa",
"rev": "fbcf476f790d8a217c3eab4e12033dc4a0f6d23c",
"type": "github"
},
"original": {
@@ -178,11 +178,11 @@
]
},
"locked": {
"lastModified": 1750732748,
"narHash": "sha256-HR2b3RHsPeJm+Fb+1ui8nXibgniVj7hBNvUbXEyz0DU=",
"lastModified": 1755485198,
"narHash": "sha256-C3042ST2lUg0nh734gmuP4lRRIBitA6Maegg2/jYRM4=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "4b4494b2ba7e8a8041b2e28320b2ee02c115c75f",
"rev": "aa45e63d431b28802ca4490cfc796b9e31731df7",
"type": "github"
},
"original": {

View File

@@ -22,7 +22,7 @@
inputs.nixpkgs.follows = "nixpkgs";
};
mnn-src = {
url = "github:alibaba/MNN/3.2.0";
url = "github:alibaba/MNN/3.2.2";
flake = false;
};
};
@@ -49,7 +49,7 @@
mnn = mnn-overlay.packages.${system}.mnn.override {
src = mnn-src;
buildConverter = true;
enableMetal = true;
enableMetal = pkgs.stdenv.isDarwin;
enableOpencl = true;
};
})
@@ -61,17 +61,44 @@
stableToolchain = pkgs.rust-bin.stable.latest.default;
stableToolchainWithLLvmTools = stableToolchain.override {
extensions = ["rust-src" "llvm-tools"];
extensions = [
"rust-src"
"llvm-tools"
];
};
stableToolchainWithRustAnalyzer = stableToolchain.override {
extensions = ["rust-src" "rust-analyzer"];
extensions = [
"rust-src"
"rust-analyzer"
];
};
craneLib = (crane.mkLib pkgs).overrideToolchain stableToolchain;
craneLibLLvmTools = (crane.mkLib pkgs).overrideToolchain stableToolchainWithLLvmTools;
ort_static = pkgs.onnxruntime.overrideAttrs (old: {
cmakeFlags =
old.cmakeFlags
++ [
"-Donnxruntime_BUILD_SHARED_LIB=OFF"
"-Donnxruntime_BUILD_STATIC_LIB=ON"
];
});
patchedOnnxruntime = pkgs.onnxruntime.overrideAttrs (old: {
patches = [./patches/ort_env_global_mutex.patch];
});
src = let
filterBySuffix = path: exts: lib.any (ext: lib.hasSuffix ext path) exts;
sourceFilters = path: type: (craneLib.filterCargoSources path type) || filterBySuffix path [".c" ".h" ".hpp" ".cpp" ".cc"];
sourceFilters = path: type:
(craneLib.filterCargoSources path type)
|| filterBySuffix path [
".c"
".h"
".hpp"
".cpp"
".cc"
".mnn"
".onnx"
];
in
lib.cleanSourceWith {
filter = sourceFilters;
@@ -81,15 +108,21 @@
{
inherit src;
pname = name;
stdenv = pkgs.clangStdenv;
stdenv = p: p.clangStdenv;
doCheck = false;
LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib";
# nativeBuildInputs = with pkgs; [
# cmake
# llvmPackages.libclang.lib
# ];
# ORT_LIB_LOCATION = "${patchedOnnxruntime}";
# ORT_ENV_SYSTEM_LIB_LOCATION = "${patchedOnnxruntime}/lib";
# ORT_ENV_PREFER_DYNAMIC_LINK = true;
nativeBuildInputs = with pkgs; [
cmake
pkg-config
];
buildInputs = with pkgs;
[]
[
patchedOnnxruntime
sqlite
]
++ (lib.optionals pkgs.stdenv.isDarwin [
libiconv
apple-sdk_13
@@ -102,11 +135,13 @@
in {
checks =
{
"${name}-clippy" = craneLib.cargoClippy (commonArgs
"${name}-clippy" = craneLib.cargoClippy (
commonArgs
// {
inherit cargoArtifacts;
cargoClippyExtraArgs = "--all-targets -- --deny warnings";
});
}
);
"${name}-docs" = craneLib.cargoDoc (commonArgs // {inherit cargoArtifacts;});
"${name}-fmt" = craneLib.cargoFmt {inherit src;};
"${name}-toml-fmt" = craneLib.taploFmt {
@@ -121,22 +156,29 @@
"${name}-deny" = craneLib.cargoDeny {
inherit src;
};
"${name}-nextest" = craneLib.cargoNextest (commonArgs
"${name}-nextest" = craneLib.cargoNextest (
commonArgs
// {
inherit cargoArtifacts;
partitions = 1;
partitionType = "count";
});
}
);
}
// lib.optionalAttrs (!pkgs.stdenv.isDarwin) {
"${name}-llvm-cov" = craneLibLLvmTools.cargoLlvmCov (commonArgs // {inherit cargoArtifacts;});
};
packages = let
pkg = craneLib.buildPackage (commonArgs
// {inherit cargoArtifacts;}
pkg = craneLib.buildPackage (
commonArgs
// {
nativeBuildInputs = with pkgs; [
inherit cargoArtifacts;
}
// {
nativeBuildInputs = with pkgs;
commonArgs.nativeBuildInputs
++ [
installShellFiles
];
postInstall = ''
@@ -145,28 +187,36 @@
--fish <($out/bin/${name} completions fish) \
--zsh <($out/bin/${name} completions zsh)
'';
});
}
);
in {
"${name}" = pkg;
default = pkg;
onnxruntime = ort_static;
};
devShells = {
default = pkgs.mkShell.override {stdenv = pkgs.clangStdenv;} (commonArgs
default = pkgs.mkShell.override {stdenv = pkgs.clangStdenv;} (
commonArgs
// {
LLDB_DEBUGSERVER_PATH = "/Applications/Xcode.app/Contents/SharedFrameworks/LLDB.framework/Versions/A/Resources/debugserver";
packages = with pkgs;
[
stableToolchainWithRustAnalyzer
cargo-expand
cargo-outdated
cargo-nextest
cargo-deny
cmake
mnn
cargo-make
hyperfine
]
++ (lib.optionals pkgs.stdenv.isDarwin [
apple-sdk_13
]);
});
}
);
};
}
)

View File

@@ -1,2 +1,13 @@
run:
cargo run -r detect -- ./1000066593.jpg
run_onnx ep = "cpu" arg = "selfie.jpg":
cargo run -r detect -p {{ep}} -t 0.3 -o detected.jpg -- {{arg}}
run_mnn forward = "cpu" arg = "selfie.jpg":
cargo run -r detect -f {{forward}} -o detected.jpg -- {{arg}}
open:
open detected.jpg
bench:
cargo build --release
BINARY="" hyperfine --warmup 3 --export-markdown benchmark.md \
"$CARGO_TARGET_DIR/release/detector detect -f cpu selfie.jpg" \
"$CARGO_TARGET_DIR/release/detector detect -f cpu -b 1 selfie.jpg"

BIN
models/facenet.mnn LFS Normal file

Binary file not shown.

BIN
models/facenet.onnx LFS Normal file

Binary file not shown.

Binary file not shown.

BIN
models/retinaface.onnx LFS Normal file

Binary file not shown.

View File

@@ -5,7 +5,7 @@ fn shape_error() -> ndarray::ShapeError {
mod rgb8 {
use super::Result;
pub(super) fn image_as_ndarray(image: &image::RgbImage) -> Result<ndarray::ArrayView3<u8>> {
pub(super) fn image_as_ndarray(image: &image::RgbImage) -> Result<ndarray::ArrayView3<'_, u8>> {
let (width, height) = image.dimensions();
let data = image.as_raw();
ndarray::ArrayView3::from_shape((height as usize, width as usize, 3), data)
@@ -31,7 +31,9 @@ mod rgb8 {
mod rgba8 {
use super::Result;
pub(super) fn image_as_ndarray(image: &image::RgbaImage) -> Result<ndarray::ArrayView3<u8>> {
pub(super) fn image_as_ndarray(
image: &image::RgbaImage,
) -> Result<ndarray::ArrayView3<'_, u8>> {
let (width, height) = image.dimensions();
let data = image.as_raw();
ndarray::ArrayView3::from_shape((height as usize, width as usize, 4), data)
@@ -57,7 +59,9 @@ mod rgba8 {
mod gray8 {
use super::Result;
pub(super) fn image_as_ndarray(image: &image::GrayImage) -> Result<ndarray::ArrayView2<u8>> {
pub(super) fn image_as_ndarray(
image: &image::GrayImage,
) -> Result<ndarray::ArrayView2<'_, u8>> {
let (width, height) = image.dimensions();
let data = image.as_raw();
ndarray::ArrayView2::from_shape((height as usize, width as usize), data)
@@ -82,7 +86,7 @@ mod gray_alpha8 {
use super::Result;
pub(super) fn image_as_ndarray(
image: &image::GrayAlphaImage,
) -> Result<ndarray::ArrayView3<u8>> {
) -> Result<ndarray::ArrayView3<'_, u8>> {
let (width, height) = image.dimensions();
let data = image.as_raw();
ndarray::ArrayView3::from_shape((height as usize, width as usize, 2), data)
@@ -110,7 +114,7 @@ mod gray_alpha8 {
mod dynamic_image {
use super::*;
pub fn image_as_ndarray(image: &image::DynamicImage) -> Result<ndarray::ArrayViewD<u8>> {
pub fn image_as_ndarray(image: &image::DynamicImage) -> Result<ndarray::ArrayViewD<'_, u8>> {
Ok(match image {
image::DynamicImage::ImageRgb8(img) => rgb8::image_as_ndarray(img)?.into_dyn(),
image::DynamicImage::ImageRgba8(img) => rgba8::image_as_ndarray(img)?.into_dyn(),

View File

@@ -147,7 +147,7 @@ impl<S: ndarray::Data<Elem = T>, T: seal::Sealed + bytemuck::Pod, D: ndarray::Di
NdAsImage<T, D> for ndarray::ArrayBase<S, D>
{
/// Clones self and makes a new image
fn as_image_ref(&self) -> Result<ImageRef> {
fn as_image_ref(&self) -> Result<ImageRef<'_>> {
let shape = self.shape();
let rows = *shape
.first()

View File

@@ -0,0 +1,11 @@
[package]
name = "ndarray-safetensors"
version.workspace = true
edition.workspace = true
[dependencies]
bytemuck = { version = "1.23.2" }
half = { version = "2.6.0", default-features = false, features = ["bytemuck"] }
ndarray = { version = "0.16.1", default-features = false, features = ["std"] }
safetensors = "0.6.2"
thiserror = "2.0.15"

View File

@@ -0,0 +1,449 @@
//! # ndarray-safetensors
//!
//! A Rust library for serializing and deserializing `ndarray` arrays using the SafeTensors format.
//!
//! ## Features
//! - Serialize `ndarray::ArrayView` to SafeTensors format
//! - Deserialize SafeTensors data back to `ndarray::ArrayView`
//! - Support for multiple data types (f32, f64, i8-i64, u8-u64, f16, bf16)
//! - Zero-copy deserialization when possible
//! - Metadata support
//!
//! ## Example
//! ```rust
//! use ndarray::Array2;
//! use ndarray_safetensors::{SafeArrays, SafeArraysView};
//!
//! // Create some data
//! let array = Array2::<f32>::zeros((3, 4));
//!
//! // Serialize
//! let mut safe_arrays = SafeArrays::new();
//! safe_arrays.insert_ndarray("my_tensor", array.view()).unwrap();
//! safe_arrays.insert_metadata("author", "example");
//! let bytes = safe_arrays.serialize().unwrap();
//!
//! // Deserialize
//! let view = SafeArraysView::from_bytes(&bytes).unwrap();
//! let tensor: ndarray::ArrayView2<f32> = view.tensor("my_tensor").unwrap();
//! assert_eq!(tensor.shape(), &[3, 4]);
//! ```
use safetensors::View;
use std::borrow::Cow;
use std::collections::{BTreeMap, HashMap};
use thiserror::Error;
/// Errors that can occur during SafeTensor operations
#[derive(Error, Debug)]
pub enum SafeTensorError {
#[error("Tensor not found: {0}")]
TensorNotFound(String),
#[error("Invalid tensor data: Got {0} Expected: {1}")]
InvalidTensorData(&'static str, String),
#[error("IO error: {0}")]
IoError(#[from] std::io::Error),
#[error("Safetensor error: {0}")]
SafeTensor(#[from] safetensors::SafeTensorError),
#[error("ndarray::ShapeError error: {0}")]
NdarrayShapeError(#[from] ndarray::ShapeError),
}
type Result<T, E = SafeTensorError> = core::result::Result<T, E>;
use safetensors::tensor::SafeTensors;
/// A view into SafeTensors data that provides access to ndarray tensors
///
/// # Example
/// ```rust
/// use ndarray::Array2;
/// use ndarray_safetensors::{SafeArrays, SafeArraysView};
///
/// let array = Array2::<f32>::ones((2, 3));
/// let mut safe_arrays = SafeArrays::new();
/// safe_arrays.insert_ndarray("data", array.view()).unwrap();
/// let bytes = safe_arrays.serialize().unwrap();
///
/// let view = SafeArraysView::from_bytes(&bytes).unwrap();
/// let tensor: ndarray::ArrayView2<f32> = view.tensor("data").unwrap();
/// ```
#[derive(Debug)]
pub struct SafeArraysView<'a> {
pub tensors: SafeTensors<'a>,
}
impl<'a> SafeArraysView<'a> {
fn new(tensors: SafeTensors<'a>) -> Self {
Self { tensors }
}
/// Create a SafeArrayView from serialized bytes
pub fn from_bytes(bytes: &'a [u8]) -> Result<SafeArraysView<'a>> {
let tensors = SafeTensors::deserialize(bytes)?;
Ok(Self::new(tensors))
}
/// Get a dynamic-dimensional tensor by name
pub fn dynamic_tensor<T: STDtype>(&self, name: &str) -> Result<ndarray::ArrayViewD<'a, T>> {
self.tensors
.tensor(name)
.map(|tensor| tensor_view_to_array_view(tensor))?
}
/// Get a tensor with specific dimensions by name
///
/// # Example
/// ```rust
/// # use ndarray::Array2;
/// # use ndarray_safetensors::{SafeArrays, SafeArraysView};
/// # let array = Array2::<f32>::ones((2, 3));
/// # let mut safe_arrays = SafeArrays::new();
/// # safe_arrays.insert_ndarray("data", array.view()).unwrap();
/// # let bytes = safe_arrays.serialize().unwrap();
/// # let view = SafeArraysView::from_bytes(&bytes).unwrap();
/// let tensor: ndarray::ArrayView2<f32> = view.tensor("data").unwrap();
/// ```
pub fn tensor<T: STDtype, Dim: ndarray::Dimension>(
&self,
name: &str,
) -> Result<ndarray::ArrayView<'a, T, Dim>> {
Ok(self
.tensors
.tensor(name)
.map(|tensor| tensor_view_to_array_view(tensor))?
.map(|array_view| array_view.into_dimensionality::<Dim>())??)
}
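/// Get a tensor by its position in the underlying `SafeTensors` listing
///
/// Returns `TensorNotFound` if `index` is out of bounds; the ordering
/// follows `SafeTensors::iter`.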
pub fn tensor_by_index<T: STDtype, Dim: ndarray::Dimension>(
&self,
index: usize,
) -> Result<ndarray::ArrayView<'a, T, Dim>> {
self.tensors
.iter()
.nth(index)
.ok_or(SafeTensorError::TensorNotFound(format!(
"Index {} out of bounds",
index
)))
.map(|(_, tensor)| tensor_view_to_array_view(tensor))?
.map(|array_view| array_view.into_dimensionality::<Dim>())?
.map_err(SafeTensorError::NdarrayShapeError)
}
/// Get an iterator over tensor names
pub fn names(&self) -> std::vec::IntoIter<&str> {
self.tensors.names().into_iter()
}
/// Get the number of tensors
pub fn len(&self) -> usize {
self.tensors.len()
}
/// Check if there are no tensors
pub fn is_empty(&self) -> bool {
self.tensors.is_empty()
}
}
/// Trait for types that can be stored in SafeTensors
///
/// Implemented for: f32, f64, i8, i16, i32, i64, u8, u16, u32, u64, f16, bf16
pub trait STDtype: bytemuck::Pod {
fn dtype() -> safetensors::tensor::Dtype;
fn size() -> usize {
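// Never report less than one byte, even if the dtype is narrower than 8 bits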
(Self::dtype().bitsize() / 8).max(1)
}
}
macro_rules! impl_dtype {
($($t:ty => $dtype:expr),* $(,)?) => {
$(
impl STDtype for $t {
fn dtype() -> safetensors::tensor::Dtype {
$dtype
}
}
)*
};
}
use safetensors::tensor::Dtype;
impl_dtype!(
// bool => Dtype::BOOL, // omitted: unclear whether ndarray::ArrayD<bool> is bit-packed
f32 => Dtype::F32,
f64 => Dtype::F64,
i8 => Dtype::I8,
i16 => Dtype::I16,
i32 => Dtype::I32,
i64 => Dtype::I64,
u8 => Dtype::U8,
u16 => Dtype::U16,
u32 => Dtype::U32,
u64 => Dtype::U64,
half::f16 => Dtype::F16,
half::bf16 => Dtype::BF16,
);
fn tensor_view_to_array_view<'a, T: STDtype>(
tensor: safetensors::tensor::TensorView<'a>,
) -> Result<ndarray::ArrayViewD<'a, T>> {
let shape = tensor.shape();
let dtype = tensor.dtype();
if T::dtype() != dtype {
return Err(SafeTensorError::InvalidTensorData(
core::any::type_name::<T>(),
dtype.to_string(),
));
}
let data = tensor.data();
let data: &[T] = bytemuck::cast_slice(data);
let array = ndarray::ArrayViewD::from_shape(shape, data)?;
Ok(array)
}
/// Builder for creating SafeTensors data from ndarray tensors
///
/// # Example
/// ```rust
/// use ndarray::{Array1, Array2};
/// use ndarray_safetensors::SafeArrays;
///
/// let mut safe_arrays = SafeArrays::new();
///
/// let array1 = Array1::<f32>::from_vec(vec![1.0, 2.0, 3.0]);
/// let array2 = Array2::<i32>::zeros((2, 2));
///
/// safe_arrays.insert_ndarray("vector", array1.view()).unwrap();
/// safe_arrays.insert_ndarray("matrix", array2.view()).unwrap();
/// safe_arrays.insert_metadata("version", "1.0");
///
/// let bytes = safe_arrays.serialize().unwrap();
/// ```
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct SafeArrays<'a> {
pub tensors: BTreeMap<String, SafeArray<'a>>,
pub metadata: Option<HashMap<String, String>>,
}
impl<'a, K: AsRef<str>> FromIterator<(K, SafeArray<'a>)> for SafeArrays<'a> {
fn from_iter<T: IntoIterator<Item = (K, SafeArray<'a>)>>(iter: T) -> Self {
let tensors = iter
.into_iter()
.map(|(k, v)| (k.as_ref().to_owned(), v))
.collect();
Self {
tensors,
metadata: None,
}
}
}
impl<'a, K: AsRef<str>, T: IntoIterator<Item = (K, SafeArray<'a>)>> From<T> for SafeArrays<'a> {
fn from(iter: T) -> Self {
let tensors = iter
.into_iter()
.map(|(k, v)| (k.as_ref().to_owned(), v))
.collect();
Self {
tensors,
metadata: None,
}
}
}
impl<'a> SafeArrays<'a> {
/// Create a SafeArrays from an iterator of (name, ndarray::ArrayView) pairs
/// ```rust
/// use ndarray::{Array2, Array3};
/// use ndarray_safetensors::{SafeArrays, SafeArray};
/// let array = Array2::<f32>::zeros((3, 4));
/// let safe_arrays = SafeArrays::from_ndarrays(vec![
/// ("test_tensor", array.view()),
/// ("test_tensor2", array.view()),
/// ]).unwrap();
/// ```
pub fn from_ndarrays<
K: AsRef<str>,
T: STDtype,
D: ndarray::Dimension + 'a,
I: IntoIterator<Item = (K, ndarray::ArrayView<'a, T, D>)>,
>(
iter: I,
) -> Result<Self> {
let tensors = iter
.into_iter()
.map(|(k, v)| Ok((k.as_ref().to_owned(), SafeArray::from_ndarray(v)?)))
.collect::<Result<BTreeMap<String, SafeArray<'a>>>>()?;
Ok(Self {
tensors,
metadata: None,
})
}
}
impl<'a> SafeArrays<'a> {
/// Create a new empty SafeArrays builder
pub const fn new() -> Self {
Self {
tensors: BTreeMap::new(),
metadata: None,
}
}
/// Insert a SafeArray tensor with the given name
pub fn insert_tensor<'b: 'a>(&mut self, name: impl AsRef<str>, tensor: SafeArray<'b>) {
self.tensors.insert(name.as_ref().to_owned(), tensor);
}
/// Insert an ndarray tensor with the given name
///
/// The array must be in standard layout and contiguous.
pub fn insert_ndarray<'b: 'a, T: STDtype, D: ndarray::Dimension + 'a>(
&mut self,
name: impl AsRef<str>,
array: ndarray::ArrayView<'b, T, D>,
) -> Result<()> {
self.insert_tensor(name, SafeArray::from_ndarray(array)?);
Ok(())
}
/// Insert metadata key-value pair
pub fn insert_metadata(&mut self, key: impl AsRef<str>, value: impl AsRef<str>) {
self.metadata
.get_or_insert_default()
.insert(key.as_ref().to_owned(), value.as_ref().to_owned());
}
/// Serialize all tensors and metadata to bytes
pub fn serialize(self) -> Result<Vec<u8>> {
let out = safetensors::serialize(self.tensors, self.metadata)
.map_err(SafeTensorError::SafeTensor)?;
Ok(out)
}
}
/// A tensor that can be serialized to SafeTensors format
#[derive(Debug, Clone)]
pub struct SafeArray<'a> {
data: Cow<'a, [u8]>,
shape: Vec<usize>,
dtype: safetensors::tensor::Dtype,
}
impl View for SafeArray<'_> {
fn dtype(&self) -> safetensors::tensor::Dtype {
self.dtype
}
fn shape(&self) -> &[usize] {
&self.shape
}
fn data(&self) -> Cow<'_, [u8]> {
self.data.clone()
}
fn data_len(&self) -> usize {
self.data.len()
}
}
impl<'a> SafeArray<'a> {
fn from_ndarray<'b: 'a, T: STDtype, D: ndarray::Dimension + 'a>(
array: ndarray::ArrayView<'b, T, D>,
) -> Result<Self> {
let shape = array.shape().to_vec();
let dtype = T::dtype();
if array.ndim() == 0 {
return Err(SafeTensorError::InvalidTensorData(
core::any::type_name::<T>(),
"Cannot insert a scalar tensor".to_string(),
));
}
if !array.is_standard_layout() {
return Err(SafeTensorError::InvalidTensorData(
core::any::type_name::<T>(),
"ArrayView is not standard layout".to_string(),
));
}
let data =
bytemuck::cast_slice(array.to_slice().ok_or(SafeTensorError::InvalidTensorData(
core::any::type_name::<T>(),
"ArrayView is not contiguous".to_string(),
))?);
let safe_array = SafeArray {
data: Cow::Borrowed(data),
shape,
dtype,
};
Ok(safe_array)
}
}
#[test]
fn test_safe_array_from_ndarray() {
use ndarray::Array2;
let array = Array2::<f32>::zeros((3, 4));
let safe_array = SafeArray::from_ndarray(array.view()).unwrap();
assert_eq!(safe_array.shape, vec![3, 4]);
assert_eq!(safe_array.dtype, safetensors::tensor::Dtype::F32);
assert_eq!(safe_array.data.len(), 3 * 4 * 4); // 3 * 4 elements * 4 bytes per f32
}
#[test]
fn test_serialize_safe_arrays() {
use ndarray::{Array2, Array3};
let mut safe_arrays = SafeArrays::new();
let array = Array2::<f32>::zeros((3, 4));
let array2 = Array3::<u16>::zeros((8, 1, 9));
safe_arrays
.insert_ndarray("test_tensor", array.view())
.unwrap();
safe_arrays
.insert_ndarray("test_tensor2", array2.view())
.unwrap();
safe_arrays.insert_metadata("author", "example");
let serialized = safe_arrays.serialize().unwrap();
assert!(!serialized.is_empty());
// Deserialize to check if it works
let deserialized = SafeArraysView::from_bytes(&serialized).unwrap();
assert_eq!(deserialized.len(), 2);
assert_eq!(
deserialized
.tensor::<f32, ndarray::Ix2>("test_tensor")
.unwrap()
.shape(),
&[3, 4]
);
assert_eq!(
deserialized
.tensor::<u16, ndarray::Ix3>("test_tensor2")
.unwrap()
.shape(),
&[8, 1, 9]
);
}


@@ -0,0 +1,42 @@
From 83e1dbf52b7695a2966795e0350aaa385d1ba8c8 Mon Sep 17 00:00:00 2001
From: "Carson M." <carson@pyke.io>
Date: Sun, 22 Jun 2025 23:53:20 -0500
Subject: [PATCH] Leak logger mutex
---
onnxruntime/core/common/logging/logging.cc | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/onnxruntime/core/common/logging/logging.cc b/onnxruntime/core/common/logging/logging.cc
index a79e7300cffce..07578fc72ca99 100644
--- a/onnxruntime/core/common/logging/logging.cc
+++ b/onnxruntime/core/common/logging/logging.cc
@@ -64,8 +64,8 @@ LoggingManager* LoggingManager::GetDefaultInstance() {
#pragma warning(disable : 26426)
#endif
-static std::mutex& DefaultLoggerMutex() noexcept {
- static std::mutex mutex;
+static std::mutex* DefaultLoggerMutex() noexcept {
+ static std::mutex* mutex = new std::mutex();
return mutex;
}
@@ -107,7 +107,7 @@ LoggingManager::LoggingManager(std::unique_ptr<ISink> sink, Severity default_min
// lock mutex to create instance, and enable logging
// this matches the mutex usage in Shutdown
- std::lock_guard<std::mutex> guard(DefaultLoggerMutex());
+ std::lock_guard<std::mutex> guard(*DefaultLoggerMutex());
if (DefaultLoggerManagerInstance().load() != nullptr) {
ORT_THROW("Only one instance of LoggingManager created with InstanceType::Default can exist at any point in time.");
@@ -127,7 +127,7 @@ LoggingManager::LoggingManager(std::unique_ptr<ISink> sink, Severity default_min
LoggingManager::~LoggingManager() {
if (owns_default_logger_) {
// lock mutex to reset DefaultLoggerManagerInstance() and free default logger from this instance.
- std::lock_guard<std::mutex> guard(DefaultLoggerMutex());
+ std::lock_guard<std::mutex> guard(*DefaultLoggerMutex());
#if ((__cplusplus >= 201703L) || (defined(_MSVC_LANG) && (_MSVC_LANG >= 201703L)))
DefaultLoggerManagerInstance().store(nullptr, std::memory_order_release);
#else

rfcs Submodule

Submodule rfcs added at c973203daf


@@ -0,0 +1,14 @@
[package]
name = "sqlite3-safetensor-cosine"
version.workspace = true
edition.workspace = true
[lib]
crate-type = ["cdylib", "staticlib"]
[dependencies]
ndarray = "0.16.1"
# ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" }
ndarray-math = { path = "/Users/fs0c131y/Projects/ndarray-math", version = "0.1.0" }
ndarray-safetensors = { version = "0.1.0", path = "../ndarray-safetensors" }
sqlite-loadable = "0.0.5"


@@ -0,0 +1,61 @@
use sqlite_loadable::prelude::*;
use sqlite_loadable::{Error, ErrorKind};
use sqlite_loadable::{Result, api, define_scalar_function};
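// SQL scalar: cosine_similarity(blob_a, blob_b) -> REAL.
// Each blob is expected to be safetensors-serialized; the first tensor in
// each blob is read as a 1-D f32 embedding.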
fn cosine_similarity(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> {
#[inline(always)]
fn custom_error(err: impl core::error::Error) -> sqlite_loadable::Error {
sqlite_loadable::Error::new(sqlite_loadable::ErrorKind::Message(err.to_string()))
}
if values.len() != 2 {
return Err(Error::new(ErrorKind::Message(
"cosine_similarity requires exactly 2 arguments".to_string(),
)));
}
let array_1 = api::value_blob(values.get(0).expect("1st argument"));
let array_2 = api::value_blob(values.get(1).expect("2nd argument"));
let array_1_st =
ndarray_safetensors::SafeArraysView::from_bytes(array_1).map_err(custom_error)?;
let array_2_st =
ndarray_safetensors::SafeArraysView::from_bytes(array_2).map_err(custom_error)?;
let array_view_1 = array_1_st
.tensor_by_index::<f32, ndarray::Ix1>(0)
.map_err(custom_error)?;
let array_view_2 = array_2_st
.tensor_by_index::<f32, ndarray::Ix1>(0)
.map_err(custom_error)?;
use ndarray_math::*;
let similarity = array_view_1
.cosine_similarity(array_view_2)
.map_err(custom_error)?;
api::result_double(context, similarity as f64);
Ok(())
}
pub fn _sqlite3_extension_init(db: *mut sqlite3) -> Result<()> {
define_scalar_function(
db,
"cosine_similarity",
2,
cosine_similarity,
FunctionFlags::DETERMINISTIC,
)?;
Ok(())
}
/// # Safety
///
/// Should only be called by underlying SQLite C APIs,
/// like sqlite3_auto_extension and sqlite3_cancel_auto_extension.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn sqlite3_extension_init(
db: *mut sqlite3,
pz_err_msg: *mut *mut c_char,
p_api: *mut sqlite3_api_routines,
) -> c_uint {
register_entrypoint(db, pz_err_msg, p_api, _sqlite3_extension_init)
}

src/bin/detector-cli/cli.rs

@@ -0,0 +1,195 @@
use detector::ort_ep;
use std::path::PathBuf;
#[derive(Debug, clap::Parser)]
pub struct Cli {
#[clap(subcommand)]
pub cmd: SubCommand,
}
#[derive(Debug, clap::Subcommand)]
pub enum SubCommand {
#[clap(name = "detect")]
Detect(Detect),
#[clap(name = "detect-multi")]
DetectMulti(DetectMulti),
#[clap(name = "query")]
Query(Query),
#[clap(name = "similar")]
Similar(Similar),
#[clap(name = "stats")]
Stats(Stats),
#[clap(name = "compare")]
Compare(Compare),
#[clap(name = "gui")]
Gui,
#[clap(name = "completions")]
Completions { shell: clap_complete::Shell },
}
#[derive(Debug, clap::ValueEnum, Clone, Copy, PartialEq)]
pub enum Models {
RetinaFace,
Yolo,
}
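/// Inference backend: MNN with a forward type, or ONNX Runtime with a list
/// of execution providers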
#[derive(Debug, Clone)]
pub enum Executor {
Mnn(mnn::ForwardType),
Ort(Vec<ort_ep::ExecutionProvider>),
}
#[derive(Debug, clap::Args)]
pub struct Detect {
#[clap(short, long)]
pub model: Option<PathBuf>,
#[clap(short = 'M', long, default_value = "retina-face")]
pub model_type: Models,
#[clap(short, long)]
pub output: Option<PathBuf>,
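// `ort_execution_provider` and `mnn_forward_type` share a clap group, so
// they are mutually exclusive and at least one of them must be supplied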
#[clap(
short = 'p',
long,
default_value = "cpu",
group = "execution_provider",
required_unless_present = "mnn_forward_type"
)]
pub ort_execution_provider: Vec<ort_ep::ExecutionProvider>,
#[clap(
short = 'f',
long,
group = "execution_provider",
required_unless_present = "ort_execution_provider"
)]
pub mnn_forward_type: Option<mnn::ForwardType>,
#[clap(short, long, default_value_t = 0.8)]
pub threshold: f32,
#[clap(short, long, default_value_t = 0.3)]
pub nms_threshold: f32,
#[clap(short, long, default_value_t = 8)]
pub batch_size: usize,
#[clap(short = 'd', long)]
pub database: Option<PathBuf>,
#[clap(long, default_value = "facenet")]
pub model_name: String,
#[clap(long)]
pub save_to_db: bool,
pub image: PathBuf,
}
#[derive(Debug, clap::Args)]
pub struct DetectMulti {
#[clap(short, long)]
pub model: Option<PathBuf>,
#[clap(short = 'M', long, default_value = "retina-face")]
pub model_type: Models,
#[clap(short, long)]
pub output_dir: Option<PathBuf>,
#[clap(
short = 'p',
long,
default_value = "cpu",
group = "execution_provider",
required_unless_present = "mnn_forward_type"
)]
pub ort_execution_provider: Vec<ort_ep::ExecutionProvider>,
#[clap(
short = 'f',
long,
group = "execution_provider",
required_unless_present = "ort_execution_provider"
)]
pub mnn_forward_type: Option<mnn::ForwardType>,
#[clap(short, long, default_value_t = 0.8)]
pub threshold: f32,
#[clap(short, long, default_value_t = 0.3)]
pub nms_threshold: f32,
#[clap(short, long, default_value_t = 8)]
pub batch_size: usize,
#[clap(short = 'd', long, default_value = "face_detections.db")]
pub database: PathBuf,
#[clap(long, default_value = "facenet")]
pub model_name: String,
#[clap(
long,
help = "Image extensions to process (e.g., jpg,png,jpeg)",
default_value = "jpg,jpeg,png,bmp,tiff,webp"
)]
pub extensions: String,
#[clap(help = "Directory containing images to process")]
pub input_dir: PathBuf,
}
#[derive(Debug, clap::Args)]
pub struct Query {
#[clap(short = 'd', long, default_value = "face_detections.db")]
pub database: PathBuf,
#[clap(short, long)]
pub image_id: Option<i64>,
#[clap(short, long)]
pub face_id: Option<i64>,
#[clap(long)]
pub show_embeddings: bool,
#[clap(long)]
pub show_landmarks: bool,
}
#[derive(Debug, clap::Args)]
pub struct Similar {
#[clap(short = 'd', long, default_value = "face_detections.db")]
pub database: PathBuf,
#[clap(short, long)]
pub face_id: i64,
#[clap(short, long, default_value_t = 0.7)]
pub threshold: f32,
#[clap(short, long, default_value_t = 10)]
pub limit: usize,
}
#[derive(Debug, clap::Args)]
pub struct Stats {
#[clap(short = 'd', long, default_value = "face_detections.db")]
pub database: PathBuf,
}
#[derive(Debug, clap::Args)]
pub struct Compare {
#[clap(short, long)]
pub model: Option<PathBuf>,
#[clap(short = 'M', long, default_value = "retina-face")]
pub model_type: Models,
#[clap(
short = 'p',
long,
default_value = "cpu",
group = "execution_provider",
required_unless_present = "mnn_forward_type"
)]
pub ort_execution_provider: Vec<ort_ep::ExecutionProvider>,
#[clap(
short = 'f',
long,
group = "execution_provider",
required_unless_present = "ort_execution_provider"
)]
pub mnn_forward_type: Option<mnn::ForwardType>,
#[clap(short, long, default_value_t = 0.8)]
pub threshold: f32,
#[clap(short, long, default_value_t = 0.3)]
pub nms_threshold: f32,
#[clap(short, long, default_value_t = 8)]
pub batch_size: usize,
#[clap(long, default_value = "facenet")]
pub model_name: String,
#[clap(help = "First image to compare")]
pub image1: PathBuf,
#[clap(help = "Second image to compare")]
pub image2: PathBuf,
}
impl Cli {
pub fn completions(shell: clap_complete::Shell) {
let mut command = <Cli as clap::CommandFactory>::command();
clap_complete::generate(shell, &mut command, "detector", &mut std::io::stdout());
}
}


@@ -0,0 +1,936 @@
mod cli;
use bounding_box::roi::MultiRoi;
use detector::*;
use detector::{database::FaceDatabase, facedet, facedet::FaceDetectionConfig, faceembed};
use errors::*;
use fast_image_resize::ResizeOptions;
use ndarray::*;
use ndarray_image::*;
use ndarray_resize::NdFir;
const RETINAFACE_MODEL_MNN: &[u8] = include_bytes!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/models/retinaface.mnn"
));
const FACENET_MODEL_MNN: &[u8] =
include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/models/facenet.mnn"));
const RETINAFACE_MODEL_ONNX: &[u8] = include_bytes!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/models/retinaface.onnx"
));
const FACENET_MODEL_ONNX: &[u8] =
include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/models/facenet.onnx"));
pub fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_env_filter("info")
.with_thread_ids(true)
.with_thread_names(true)
.with_target(false)
.init();
let args = <cli::Cli as clap::Parser>::parse();
match args.cmd {
cli::SubCommand::Detect(detect) => {
// Choose backend based on executor type (defaulting to MNN for backward compatibility)
let executor = detect
.mnn_forward_type
.map(cli::Executor::Mnn)
.or_else(|| {
if detect.ort_execution_provider.is_empty() {
None
} else {
Some(cli::Executor::Ort(detect.ort_execution_provider.clone()))
}
})
.unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU));
match executor {
cli::Executor::Mnn(forward) => {
let retinaface =
facedet::retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
.change_context(Error)?
.with_forward_type(forward)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face detection model")?;
let facenet =
faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
.change_context(Error)?
.with_forward_type(forward)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face embedding model")?;
run_detection(detect, retinaface, facenet)?;
}
cli::Executor::Ort(ep) => {
let retinaface =
facedet::retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
.change_context(Error)?
.with_execution_providers(&ep)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face detection model")?;
let facenet =
faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
.change_context(Error)?
.with_execution_providers(ep)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face embedding model")?;
run_detection(detect, retinaface, facenet)?;
}
}
}
cli::SubCommand::DetectMulti(detect_multi) => {
// Choose backend based on executor type (defaulting to MNN for backward compatibility)
let executor = detect_multi
.mnn_forward_type
.map(cli::Executor::Mnn)
.or_else(|| {
if detect_multi.ort_execution_provider.is_empty() {
None
} else {
Some(cli::Executor::Ort(
detect_multi.ort_execution_provider.clone(),
))
}
})
.unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU));
match executor {
cli::Executor::Mnn(forward) => {
let retinaface =
facedet::retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
.change_context(Error)?
.with_forward_type(forward)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face detection model")?;
let facenet =
faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
.change_context(Error)?
.with_forward_type(forward)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face embedding model")?;
run_multi_detection(detect_multi, retinaface, facenet)?;
}
cli::Executor::Ort(ep) => {
let retinaface =
facedet::retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
.change_context(Error)?
.with_execution_providers(&ep)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face detection model")?;
let facenet =
faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
.change_context(Error)?
.with_execution_providers(ep)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face embedding model")?;
run_multi_detection(detect_multi, retinaface, facenet)?;
}
}
}
cli::SubCommand::Query(query) => {
run_query(query)?;
}
cli::SubCommand::Similar(similar) => {
run_similar(similar)?;
}
cli::SubCommand::Stats(stats) => {
run_stats(stats)?;
}
cli::SubCommand::Compare(compare) => {
// Choose backend based on executor type (defaulting to MNN for backward compatibility)
let executor = compare
.mnn_forward_type
.map(cli::Executor::Mnn)
.or_else(|| {
if compare.ort_execution_provider.is_empty() {
None
} else {
Some(cli::Executor::Ort(compare.ort_execution_provider.clone()))
}
})
.unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU));
match executor {
cli::Executor::Mnn(forward) => {
let retinaface =
facedet::retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
.change_context(Error)?
.with_forward_type(forward)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face detection model")?;
let facenet =
faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
.change_context(Error)?
.with_forward_type(forward)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face embedding model")?;
run_compare(compare, retinaface, facenet)?;
}
cli::Executor::Ort(ep) => {
let retinaface =
facedet::retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
.change_context(Error)?
.with_execution_providers(&ep)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face detection model")?;
let facenet =
faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
.change_context(Error)?
.with_execution_providers(ep)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face embedding model")?;
run_compare(compare, retinaface, facenet)?;
}
}
}
cli::SubCommand::Gui => {
if let Err(e) = detector::gui::run() {
eprintln!("GUI error: {}", e);
std::process::exit(1);
}
}
cli::SubCommand::Completions { shell } => {
cli::Cli::completions(shell);
}
}
Ok(())
}
fn run_detection<D, E>(detect: cli::Detect, mut retinaface: D, mut facenet: E) -> Result<()>
where
D: facedet::FaceDetector,
E: faceembed::FaceEmbedder,
{
// Initialize database if requested
let db = if detect.save_to_db {
let db_path = detect
.database
.as_ref()
.map(|p| p.as_path())
.unwrap_or_else(|| std::path::Path::new("face_detections.db"));
Some(FaceDatabase::new(db_path).change_context(Error)?)
} else {
None
};
let image = image::open(&detect.image)
.change_context(Error)
.attach_printable(detect.image.to_string_lossy().to_string())?;
let image = image.into_rgb8();
let (image_width, image_height) = image.dimensions();
let mut array = image
.into_ndarray()
.change_context(errors::Error)
.attach_printable("Failed to convert image to ndarray")?;
let output = retinaface
.detect_faces(
array.view(),
&FaceDetectionConfig::default()
.with_threshold(detect.threshold)
.with_nms_threshold(detect.nms_threshold),
)
.change_context(errors::Error)
.attach_printable("Failed to detect faces")?;
// Store image and face detections in database if requested
let (_image_id, face_ids) = if let Some(ref database) = db {
let image_path = detect.image.to_string_lossy();
let img_id = database
.store_image(&image_path, image_width, image_height)
.change_context(Error)?;
let face_ids = database
.store_face_detections(img_id, &output)
.change_context(Error)?;
tracing::info!(
"Stored image {} with {} faces in database",
img_id,
face_ids.len()
);
(Some(img_id), Some(face_ids))
} else {
(None, None)
};
for bbox in &output.bbox {
tracing::info!("Detected face: {:?}", bbox);
use bounding_box::draw::*;
array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1);
}
let face_rois = array
.view()
.multi_roi(&output.bbox)
.change_context(Error)?
.into_iter()
.map(|roi| {
roi.as_standard_layout()
.fast_resize(320, 320, &ResizeOptions::default())
.change_context(Error)
})
.collect::<Result<Vec<_>>>()?;
let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
let chunk_size = detect.batch_size;
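// Run the embedder in fixed-size batches; a short final chunk is
// zero-padded so the model always sees `batch_size` inputs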
let embeddings = face_roi_views
.chunks(chunk_size)
.map(|chunk| {
tracing::info!("Processing chunk of size: {}", chunk.len());
if chunk.len() < chunk_size {
tracing::warn!("Chunk size is less than 8, padding with zeros");
let zeros = Array3::zeros((320, 320, 3));
let chunk: Vec<_> = chunk
.iter()
.map(|arr| arr.reborrow())
.chain(core::iter::repeat(zeros.view()))
.take(chunk_size)
.collect();
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk.as_slice())
.change_context(errors::Error)
.attach_printable("Failed to stack rois together")?;
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
Ok(output)
} else {
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk)
.change_context(errors::Error)
.attach_printable("Failed to stack rois together")?;
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
Ok(output)
}
})
.collect::<Result<Vec<Array2<f32>>>>()?;
// Store embeddings in database if requested
if let (Some(database), Some(face_ids)) = (&db, &face_ids) {
let embedding_ids = database
.store_embeddings(face_ids, &embeddings, &detect.model_name)
.change_context(Error)?;
tracing::info!("Stored {} embeddings in database", embedding_ids.len());
// Print database statistics
let (num_images, num_faces, num_landmarks, num_embeddings) =
database.get_stats().change_context(Error)?;
tracing::info!(
"Database stats - Images: {}, Faces: {}, Landmarks: {}, Embeddings: {}",
num_images,
num_faces,
num_landmarks,
num_embeddings
);
}
let v = array.view();
if let Some(output) = detect.output {
let image: image::RgbImage = v
.to_image()
.change_context(errors::Error)
.attach_printable("Failed to convert ndarray to image")?;
image
.save(output)
.change_context(errors::Error)
.attach_printable("Failed to save output image")?;
}
Ok(())
}
fn run_query(query: cli::Query) -> Result<()> {
let db = FaceDatabase::new(&query.database).change_context(Error)?;
if let Some(image_id) = query.image_id {
if let Some(image) = db.get_image(image_id).change_context(Error)? {
println!("Image: {}", image.file_path);
println!("Dimensions: {}x{}", image.width, image.height);
println!("Created: {}", image.created_at);
let faces = db.get_faces_for_image(image_id).change_context(Error)?;
println!("Faces found: {}", faces.len());
for face in faces {
println!(
" Face ID {}: bbox({:.1}, {:.1}, {:.1}, {:.1}), confidence: {:.3}",
face.id,
face.bbox_x1,
face.bbox_y1,
face.bbox_x2,
face.bbox_y2,
face.confidence
);
if query.show_landmarks {
if let Some(landmarks) = db.get_landmarks(face.id).change_context(Error)? {
println!(
" Landmarks: left_eye({:.1}, {:.1}), right_eye({:.1}, {:.1}), nose({:.1}, {:.1})",
landmarks.left_eye_x,
landmarks.left_eye_y,
landmarks.right_eye_x,
landmarks.right_eye_y,
landmarks.nose_x,
landmarks.nose_y
);
}
}
if query.show_embeddings {
let embeddings = db.get_embeddings(face.id).change_context(Error)?;
for embedding in embeddings {
println!(
" Embedding ({}): {} dims, model: {}",
embedding.id,
embedding.embedding.len(),
embedding.model_name
);
}
}
}
} else {
println!("Image with ID {} not found", image_id);
}
}
if let Some(face_id) = query.face_id {
if let Some(landmarks) = db.get_landmarks(face_id).change_context(Error)? {
println!(
"Landmarks for face {}: left_eye({:.1}, {:.1}), right_eye({:.1}, {:.1}), nose({:.1}, {:.1})",
face_id,
landmarks.left_eye_x,
landmarks.left_eye_y,
landmarks.right_eye_x,
landmarks.right_eye_y,
landmarks.nose_x,
landmarks.nose_y
);
} else {
println!("No landmarks found for face {}", face_id);
}
let embeddings = db.get_embeddings(face_id).change_context(Error)?;
println!(
"Embeddings for face {}: {} found",
face_id,
embeddings.len()
);
for embedding in embeddings {
println!(
" Embedding {}: {} dims, model: {}, created: {}",
embedding.id,
embedding.embedding.len(),
embedding.model_name,
embedding.created_at
);
// if query.show_embeddings {
// println!(" Values: {:?}", &embedding.embedding);
// }
}
}
Ok(())
}
fn run_compare<D, E>(compare: cli::Compare, mut retinaface: D, mut facenet: E) -> Result<()>
where
D: facedet::FaceDetector,
E: faceembed::FaceEmbedder,
{
// Helper function to detect faces and compute embeddings for an image
fn process_image<D, E>(
image_path: &std::path::Path,
retinaface: &mut D,
facenet: &mut E,
config: &FaceDetectionConfig,
batch_size: usize,
) -> Result<(Vec<Array1<f32>>, usize)>
where
D: facedet::FaceDetector,
E: faceembed::FaceEmbedder,
{
let image = image::open(image_path)
.change_context(Error)
.attach_printable(image_path.to_string_lossy().to_string())?;
let image = image.into_rgb8();
let array = image
.into_ndarray()
.change_context(errors::Error)
.attach_printable("Failed to convert image to ndarray")?;
let output = retinaface
.detect_faces(array.view(), config)
.change_context(errors::Error)
.attach_printable("Failed to detect faces")?;
tracing::info!(
"Detected {} faces in {}",
output.bbox.len(),
image_path.display()
);
if output.bbox.is_empty() {
return Ok((Vec::new(), 0));
}
let face_rois = array
.view()
.multi_roi(&output.bbox)
.change_context(Error)?
.into_iter()
.map(|roi| {
roi.as_standard_layout()
.fast_resize(320, 320, &ResizeOptions::default())
.change_context(Error)
})
.collect::<Result<Vec<_>>>()?;
let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
let chunk_size = batch_size;
let embeddings = face_roi_views
.chunks(chunk_size)
.map(|chunk| {
if chunk.len() < chunk_size {
let zeros = Array3::zeros((320, 320, 3));
let chunk: Vec<_> = chunk
.iter()
.map(|arr| arr.reborrow())
.chain(core::iter::repeat(zeros.view()))
.take(chunk_size)
.collect();
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk.as_slice())
.change_context(errors::Error)
.attach_printable("Failed to stack rois together")?;
facenet.run_models(face_rois.view()).change_context(Error)
} else {
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk)
.change_context(errors::Error)
.attach_printable("Failed to stack rois together")?;
facenet.run_models(face_rois.view()).change_context(Error)
}
})
.collect::<Result<Vec<Array2<f32>>>>()?;
// Flatten batch outputs into per-face embeddings, skipping the zero-padded
// rows appended to fill the final batch
let mut face_embeddings = Vec::new();
let mut remaining = output.bbox.len();
for embedding_batch in embeddings {
let take = remaining.min(embedding_batch.nrows());
for i in 0..take {
face_embeddings.push(embedding_batch.row(i).to_owned());
}
remaining -= take;
}
Ok((face_embeddings, output.bbox.len()))
}
// Helper function to compute cosine similarity between two embeddings
fn cosine_similarity(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
let dot_product = a.dot(b);
let norm_a = a.dot(a).sqrt();
let norm_b = b.dot(b).sqrt();
dot_product / (norm_a * norm_b)
}
let config = FaceDetectionConfig::default()
.with_threshold(compare.threshold)
.with_nms_threshold(compare.nms_threshold);
// Process both images
let (embeddings1, face_count1) = process_image(
&compare.image1,
&mut retinaface,
&mut facenet,
&config,
compare.batch_size,
)?;
let (embeddings2, face_count2) = process_image(
&compare.image2,
&mut retinaface,
&mut facenet,
&config,
compare.batch_size,
)?;
println!(
"Image 1 ({}): {} faces detected",
compare.image1.display(),
face_count1
);
println!(
"Image 2 ({}): {} faces detected",
compare.image2.display(),
face_count2
);
if embeddings1.is_empty() && embeddings2.is_empty() {
println!("No faces detected in either image");
return Ok(());
}
if embeddings1.is_empty() {
println!("No faces detected in image 1");
return Ok(());
}
if embeddings2.is_empty() {
println!("No faces detected in image 2");
return Ok(());
}
// Compare all faces between the two images
println!("\nFace comparison results:");
println!("========================");
let mut max_similarity = f32::NEG_INFINITY;
let mut best_match = (0, 0);
for (i, emb1) in embeddings1.iter().enumerate() {
for (j, emb2) in embeddings2.iter().enumerate() {
let similarity = cosine_similarity(emb1, emb2);
println!(
"Face {} (image 1) vs Face {} (image 2): {:.4}",
i + 1,
j + 1,
similarity
);
if similarity > max_similarity {
max_similarity = similarity;
best_match = (i + 1, j + 1);
}
}
}
println!(
"\nBest match: Face {} (image 1) vs Face {} (image 2) with similarity: {:.4}",
best_match.0, best_match.1, max_similarity
);
// Interpretation of similarity score
if max_similarity > 0.8 {
println!("Interpretation: Very likely the same person");
} else if max_similarity > 0.6 {
println!("Interpretation: Possibly the same person");
} else if max_similarity > 0.4 {
println!("Interpretation: Unlikely to be the same person");
} else {
println!("Interpretation: Very unlikely to be the same person");
}
Ok(())
}
fn run_multi_detection<D, E>(
detect_multi: cli::DetectMulti,
mut retinaface: D,
mut facenet: E,
) -> Result<()>
where
D: facedet::FaceDetector,
E: faceembed::FaceEmbedder,
{
use std::fs;
// Initialize database - always save to database for multi-detection
let db = FaceDatabase::new(&detect_multi.database).change_context(Error)?;
// Parse supported extensions
let extensions: std::collections::HashSet<String> = detect_multi
.extensions
.split(',')
.map(|ext| ext.trim().to_lowercase())
.collect();
// Create output directory if specified
if let Some(ref output_dir) = detect_multi.output_dir {
fs::create_dir_all(output_dir)
.change_context(Error)
.attach_printable("Failed to create output directory")?;
}
// Read directory and filter image files
let entries = fs::read_dir(&detect_multi.input_dir)
.change_context(Error)
.attach_printable("Failed to read input directory")?;
let mut image_paths = Vec::new();
for entry in entries {
let entry = entry.change_context(Error)?;
let path = entry.path();
if path.is_file() {
if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
if extensions.contains(&ext.to_lowercase()) {
image_paths.push(path);
}
}
}
}
if image_paths.is_empty() {
tracing::warn!(
"No image files found in directory: {:?}",
detect_multi.input_dir
);
return Ok(());
}
tracing::info!("Found {} image files to process", image_paths.len());
let mut total_faces = 0;
let mut processed_images = 0;
// Process each image
for (idx, image_path) in image_paths.iter().enumerate() {
tracing::info!(
"Processing image {}/{}: {:?}",
idx + 1,
image_paths.len(),
image_path
);
// Load and process image
let image = match image::open(image_path) {
Ok(img) => img.into_rgb8(),
Err(e) => {
tracing::error!("Failed to load image {:?}: {}", image_path, e);
continue;
}
};
let (image_width, image_height) = image.dimensions();
let mut array = match image.into_ndarray().change_context(errors::Error) {
Ok(arr) => arr,
Err(e) => {
tracing::error!("Failed to convert image to ndarray: {:?}", e);
continue;
}
};
let config = FaceDetectionConfig::default()
.with_threshold(detect_multi.threshold)
.with_nms_threshold(detect_multi.nms_threshold);
// Detect faces
let output = match retinaface.detect_faces(array.view(), &config) {
Ok(output) => output,
Err(e) => {
tracing::error!("Failed to detect faces in {:?}: {:?}", image_path, e);
continue;
}
};
let num_faces = output.bbox.len();
total_faces += num_faces;
if num_faces == 0 {
tracing::info!("No faces detected in {:?}", image_path);
} else {
tracing::info!("Detected {} faces in {:?}", num_faces, image_path);
}
// Store image and detections in database
let image_path_str = image_path.to_string_lossy();
let img_id = match db.store_image(&image_path_str, image_width, image_height) {
Ok(id) => id,
Err(e) => {
tracing::error!("Failed to store image in database: {:?}", e);
continue;
}
};
let face_ids = match db.store_face_detections(img_id, &output) {
Ok(ids) => ids,
Err(e) => {
tracing::error!("Failed to store face detections in database: {:?}", e);
continue;
}
};
// Draw bounding boxes if output directory is specified
if detect_multi.output_dir.is_some() {
for bbox in &output.bbox {
use bounding_box::draw::*;
array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1);
}
}
// Process face embeddings if faces were detected
if !face_ids.is_empty() {
let face_rois = match array.view().multi_roi(&output.bbox).change_context(Error) {
Ok(rois) => rois,
Err(e) => {
tracing::error!("Failed to extract face ROIs: {:?}", e);
continue;
}
};
let face_rois: Result<Vec<_>> = face_rois
.into_iter()
.map(|roi| {
roi.as_standard_layout()
.fast_resize(320, 320, &ResizeOptions::default())
.change_context(Error)
})
.collect();
let face_rois = match face_rois {
Ok(rois) => rois,
Err(e) => {
tracing::error!("Failed to resize face ROIs: {:?}", e);
continue;
}
};
let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
let chunk_size = detect_multi.batch_size;
let embeddings: Result<Vec<Array2<f32>>> = face_roi_views
.chunks(chunk_size)
.map(|chunk| {
if chunk.len() < chunk_size {
let zeros = Array3::zeros((320, 320, 3));
let chunk: Vec<_> = chunk
.iter()
.map(|arr| arr.reborrow())
.chain(core::iter::repeat(zeros.view()))
.take(chunk_size)
.collect();
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk.as_slice())
.change_context(errors::Error)
.attach_printable("Failed to stack rois together")?;
facenet.run_models(face_rois.view()).change_context(Error)
} else {
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk)
.change_context(errors::Error)
.attach_printable("Failed to stack rois together")?;
facenet.run_models(face_rois.view()).change_context(Error)
}
})
.collect();
let embeddings = match embeddings {
Ok(emb) => emb,
Err(e) => {
tracing::error!("Failed to generate embeddings: {:?}", e);
continue;
}
};
// Store embeddings in database
if let Err(e) = db.store_embeddings(&face_ids, &embeddings, &detect_multi.model_name) {
tracing::error!("Failed to store embeddings in database: {:?}", e);
continue;
}
}
// Save output image if directory specified
if let Some(ref output_dir) = detect_multi.output_dir {
let output_filename = format!(
"detected_{}",
image_path.file_name().unwrap().to_string_lossy()
);
let output_path = output_dir.join(output_filename);
let v = array.view();
let output_image: image::RgbImage = match v.to_image().change_context(errors::Error) {
Ok(img) => img,
Err(e) => {
tracing::error!("Failed to convert ndarray to image: {:?}", e);
continue;
}
};
if let Err(e) = output_image.save(&output_path) {
tracing::error!("Failed to save output image to {:?}: {}", output_path, e);
continue;
}
tracing::info!("Saved output image to {:?}", output_path);
}
processed_images += 1;
}
// Print final statistics
tracing::info!(
"Processing complete: {}/{} images processed successfully, {} total faces detected",
processed_images,
image_paths.len(),
total_faces
);
let (num_images, num_faces, num_landmarks, num_embeddings) =
db.get_stats().change_context(Error)?;
tracing::info!(
"Database stats - Images: {}, Faces: {}, Landmarks: {}, Embeddings: {}",
num_images,
num_faces,
num_landmarks,
num_embeddings
);
Ok(())
}
fn run_similar(similar: cli::Similar) -> Result<()> {
let db = FaceDatabase::new(&similar.database).change_context(Error)?;
let embeddings = db.get_embeddings(similar.face_id).change_context(Error)?;
if embeddings.is_empty() {
println!("No embeddings found for face {}", similar.face_id);
return Ok(());
}
let query_embedding = &embeddings[0].embedding;
let similar_faces = db
.find_similar_faces(query_embedding, similar.threshold, similar.limit)
.change_context(Error)?;
// Get image information for the similar faces
println!(
"Found {} similar faces (threshold: {:.3}):",
similar_faces.len(),
similar.threshold
);
for (face_id, similarity) in &similar_faces {
if let Some(image_info) = db.get_image_for_face(*face_id).change_context(Error)? {
println!(
" Face {}: similarity {:.3}, image: {}",
face_id, similarity, image_info.file_path
);
}
}
Ok(())
}
fn run_stats(stats: cli::Stats) -> Result<()> {
let db = FaceDatabase::new(&stats.database).change_context(Error)?;
let (images, faces, landmarks, embeddings) = db.get_stats().change_context(Error)?;
println!("Database Statistics:");
println!(" Images: {}", images);
println!(" Faces: {}", faces);
println!(" Landmarks: {}", landmarks);
println!(" Embeddings: {}", embeddings);
Ok(())
}

src/bin/gui.rs

@@ -0,0 +1,17 @@
fn main() -> Result<(), Box<dyn std::error::Error>> {
// Initialize logging
tracing_subscriber::fmt()
.with_env_filter("info")
.with_thread_ids(true)
.with_thread_names(true)
.with_target(false)
.init();
// Run the GUI
if let Err(e) = detector::gui::run() {
eprintln!("GUI error: {}", e);
std::process::exit(1);
}
Ok(())
}


@@ -1,71 +0,0 @@
use std::path::PathBuf;
#[derive(Debug, clap::Parser)]
pub struct Cli {
#[clap(subcommand)]
pub cmd: SubCommand,
}
#[derive(Debug, clap::Subcommand)]
pub enum SubCommand {
#[clap(name = "detect")]
Detect(Detect),
#[clap(name = "list")]
List(List),
#[clap(name = "completions")]
Completions { shell: clap_complete::Shell },
}
#[derive(Debug, clap::ValueEnum, Clone, Copy)]
pub enum Models {
RetinaFace,
Yolo,
}
#[derive(Debug, clap::ValueEnum, Clone, Copy)]
pub enum Executor {
Mnn,
Onnx,
}
#[derive(Debug, clap::ValueEnum, Clone, Copy)]
pub enum OnnxEp {
Cpu,
}
#[derive(Debug, clap::ValueEnum, Clone, Copy)]
pub enum MnnEp {
Cpu,
Metal,
OpenCL,
CoreML,
}
#[derive(Debug, clap::Args)]
pub struct Detect {
#[clap(short, long)]
pub model: Option<PathBuf>,
#[clap(short = 'M', long, default_value = "retina-face")]
pub model_type: Models,
#[clap(short, long)]
pub output: Option<PathBuf>,
#[clap(short, long, default_value_t = 0.8)]
pub threshold: f32,
#[clap(short, long, default_value_t = 0.3)]
pub nms_threshold: f32,
pub image: PathBuf,
}
#[derive(Debug, clap::Args)]
pub struct List {}
impl Cli {
pub fn completions(shell: clap_complete::Shell) {
let mut command = <Cli as clap::CommandFactory>::command();
clap_complete::generate(
shell,
&mut command,
env!("CARGO_BIN_NAME"),
&mut std::io::stdout(),
);
}
}

src/database.rs

@@ -0,0 +1,663 @@
use crate::errors::{Error, Result};
use crate::facedet::{FaceDetectionOutput, FaceLandmarks};
use bounding_box::Aabb2;
use error_stack::ResultExt;
use ndarray_math::CosineSimilarity;
use rusqlite::{Connection, OptionalExtension, params};
use std::path::Path;
/// Database connection and operations for face detection results
pub struct FaceDatabase {
conn: Connection,
}
/// Represents a stored image record
#[derive(Debug, Clone)]
pub struct ImageRecord {
pub id: i64,
pub file_path: String,
pub width: u32,
pub height: u32,
pub created_at: String,
}
/// Represents a stored face detection record
#[derive(Debug, Clone)]
pub struct FaceRecord {
pub id: i64,
pub image_id: i64,
pub bbox_x1: f32,
pub bbox_y1: f32,
pub bbox_x2: f32,
pub bbox_y2: f32,
pub confidence: f32,
pub created_at: String,
}
/// Represents stored face landmarks
#[derive(Debug, Clone)]
pub struct LandmarkRecord {
pub id: i64,
pub face_id: i64,
pub left_eye_x: f32,
pub left_eye_y: f32,
pub right_eye_x: f32,
pub right_eye_y: f32,
pub nose_x: f32,
pub nose_y: f32,
pub left_mouth_x: f32,
pub left_mouth_y: f32,
pub right_mouth_x: f32,
pub right_mouth_y: f32,
}
/// Represents a stored face embedding
#[derive(Debug, Clone)]
pub struct EmbeddingRecord {
pub id: i64,
pub face_id: i64,
pub embedding: ndarray::Array1<f32>,
pub model_name: String,
pub created_at: String,
}
impl FaceDatabase {
/// Create a new database connection and initialize tables
pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Self> {
let conn = Connection::open(db_path).change_context(Error)?;
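// Load the cosine_similarity SQL function from the compiled
// sqlite3-safetensor-cosine extension (note: hardcoded, machine-specific path)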
unsafe {
let _guard = rusqlite::LoadExtensionGuard::new(&conn).change_context(Error)?;
conn.load_extension(
"/Users/fs0c131y/.cache/cargo/target/release/libsqlite3_safetensor_cosine.dylib",
None::<&str>,
)
.change_context(Error)?;
}
let db = Self { conn };
db.create_tables()?;
Ok(db)
}
/// Create an in-memory database for testing
pub fn in_memory() -> Result<Self> {
let conn = Connection::open_in_memory().change_context(Error)?;
let db = Self { conn };
db.create_tables()?;
Ok(db)
}
/// Create all necessary database tables
fn create_tables(&self) -> Result<()> {
// Images table
self.conn
.execute(
r#"
CREATE TABLE IF NOT EXISTS images (
id INTEGER PRIMARY KEY AUTOINCREMENT,
file_path TEXT NOT NULL UNIQUE,
width INTEGER NOT NULL,
height INTEGER NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
"#,
[],
)
.change_context(Error)?;
// Faces table
self.conn
.execute(
r#"
CREATE TABLE IF NOT EXISTS faces (
id INTEGER PRIMARY KEY AUTOINCREMENT,
image_id INTEGER NOT NULL,
bbox_x1 REAL NOT NULL,
bbox_y1 REAL NOT NULL,
bbox_x2 REAL NOT NULL,
bbox_y2 REAL NOT NULL,
confidence REAL NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE
)
"#,
[],
)
.change_context(Error)?;
// Landmarks table
self.conn
.execute(
r#"
CREATE TABLE IF NOT EXISTS landmarks (
id INTEGER PRIMARY KEY AUTOINCREMENT,
face_id INTEGER NOT NULL,
left_eye_x REAL NOT NULL,
left_eye_y REAL NOT NULL,
right_eye_x REAL NOT NULL,
right_eye_y REAL NOT NULL,
nose_x REAL NOT NULL,
nose_y REAL NOT NULL,
left_mouth_x REAL NOT NULL,
left_mouth_y REAL NOT NULL,
right_mouth_x REAL NOT NULL,
right_mouth_y REAL NOT NULL,
FOREIGN KEY (face_id) REFERENCES faces (id) ON DELETE CASCADE
)
"#,
[],
)
.change_context(Error)?;
// Embeddings table
self.conn
.execute(
r#"
CREATE TABLE IF NOT EXISTS embeddings (
id INTEGER PRIMARY KEY AUTOINCREMENT,
face_id INTEGER NOT NULL,
embedding BLOB NOT NULL,
model_name TEXT NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (face_id) REFERENCES faces (id) ON DELETE CASCADE
)
"#,
[],
)
.change_context(Error)?;
// Create indexes for better performance
self.conn
.execute(
"CREATE INDEX IF NOT EXISTS idx_faces_image_id ON faces (image_id)",
[],
)
.change_context(Error)?;
self.conn
.execute(
"CREATE INDEX IF NOT EXISTS idx_landmarks_face_id ON landmarks (face_id)",
[],
)
.change_context(Error)?;
self.conn
.execute(
"CREATE INDEX IF NOT EXISTS idx_embeddings_face_id ON embeddings (face_id)",
[],
)
.change_context(Error)?;
Ok(())
}
/// Store image metadata and return the image ID
pub fn store_image(&self, file_path: &str, width: u32, height: u32) -> Result<i64> {
let mut stmt = self
.conn
.prepare("INSERT OR REPLACE INTO images (file_path, width, height) VALUES (?1, ?2, ?3)")
.change_context(Error)?;
Ok(stmt
.insert(params![file_path, width, height])
.change_context(Error)?)
}
/// Store face detection results
pub fn store_face_detections(
&self,
image_id: i64,
detection_output: &FaceDetectionOutput,
) -> Result<Vec<i64>> {
let mut face_ids = Vec::new();
for (i, bbox) in detection_output.bbox.iter().enumerate() {
let confidence = detection_output.confidence.get(i).copied().unwrap_or(0.0);
let face_id = self.store_face(image_id, bbox, confidence)?;
face_ids.push(face_id);
// Store landmarks if available
if let Some(landmarks) = detection_output.landmark.get(i) {
self.store_landmarks(face_id, landmarks)?;
}
}
Ok(face_ids)
}
/// Store a single face detection
pub fn store_face(&self, image_id: i64, bbox: &Aabb2<usize>, confidence: f32) -> Result<i64> {
let mut stmt = self
.conn
.prepare(
r#"
INSERT INTO faces (image_id, bbox_x1, bbox_y1, bbox_x2, bbox_y2, confidence)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)
"#,
)
.change_context(Error)?;
Ok(stmt
.insert(params![
image_id,
bbox.x1() as f32,
bbox.y1() as f32,
bbox.x2() as f32,
bbox.y2() as f32,
confidence
])
.change_context(Error)?)
}
/// Store face landmarks
pub fn store_landmarks(&self, face_id: i64, landmarks: &FaceLandmarks) -> Result<i64> {
let mut stmt = self
.conn
.prepare(
r#"
INSERT INTO landmarks
(face_id, left_eye_x, left_eye_y, right_eye_x, right_eye_y,
nose_x, nose_y, left_mouth_x, left_mouth_y, right_mouth_x, right_mouth_y)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)
"#,
)
.change_context(Error)?;
Ok(stmt
.insert(params![
face_id,
landmarks.left_eye.x,
landmarks.left_eye.y,
landmarks.right_eye.x,
landmarks.right_eye.y,
landmarks.nose.x,
landmarks.nose.y,
landmarks.left_mouth.x,
landmarks.left_mouth.y,
landmarks.right_mouth.x,
landmarks.right_mouth.y,
])
.change_context(Error)?)
}
/// Store face embeddings
pub fn store_embeddings(
&self,
face_ids: &[i64],
embeddings: &[ndarray::Array2<f32>],
model_name: &str,
) -> Result<Vec<i64>> {
let mut embedding_ids = Vec::new();
// Each batch holds `nrows` embedding rows; rows past the number of
// detected faces are zero padding and are skipped by the bounds check below
for (chunk_idx, embedding_batch) in embeddings.iter().enumerate() {
for (row_idx, embedding_row) in embedding_batch.rows().into_iter().enumerate() {
let global_idx = chunk_idx * embedding_batch.nrows() + row_idx;
if global_idx >= face_ids.len() {
break;
}
let face_id = face_ids[global_idx];
let embedding_id =
self.store_single_embedding(face_id, embedding_row, model_name)?;
embedding_ids.push(embedding_id);
}
}
Ok(embedding_ids)
}
/// Store a single embedding
pub fn store_single_embedding(
&self,
face_id: i64,
embedding: ndarray::ArrayView1<f32>,
model_name: &str,
) -> Result<i64> {
let safe_arrays =
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding)])
.change_context(Error)?;
let embedding_bytes = safe_arrays.serialize().change_context(Error)?;
let mut stmt = self
.conn
.prepare("INSERT INTO embeddings (face_id, embedding, model_name) VALUES (?1, ?2, ?3)")
.change_context(Error)?;
stmt.execute(params![face_id, embedding_bytes, model_name])
.change_context(Error)?;
Ok(self.conn.last_insert_rowid())
}
/// Get image by ID
pub fn get_image(&self, image_id: i64) -> Result<Option<ImageRecord>> {
let mut stmt = self
.conn
.prepare("SELECT id, file_path, width, height, created_at FROM images WHERE id = ?1")
.change_context(Error)?;
let result = stmt
.query_row(params![image_id], |row| {
Ok(ImageRecord {
id: row.get(0)?,
file_path: row.get(1)?,
width: row.get(2)?,
height: row.get(3)?,
created_at: row.get(4)?,
})
})
.optional()
.change_context(Error)?;
Ok(result)
}
/// Get all faces for an image
pub fn get_faces_for_image(&self, image_id: i64) -> Result<Vec<FaceRecord>> {
let mut stmt = self
.conn
.prepare(
r#"
SELECT id, image_id, bbox_x1, bbox_y1, bbox_x2, bbox_y2, confidence, created_at
FROM faces WHERE image_id = ?1
"#,
)
.change_context(Error)?;
let face_iter = stmt
.query_map(params![image_id], |row| {
Ok(FaceRecord {
id: row.get(0)?,
image_id: row.get(1)?,
bbox_x1: row.get(2)?,
bbox_y1: row.get(3)?,
bbox_x2: row.get(4)?,
bbox_y2: row.get(5)?,
confidence: row.get(6)?,
created_at: row.get(7)?,
})
})
.change_context(Error)?;
let mut faces = Vec::new();
for face in face_iter {
faces.push(face.change_context(Error)?);
}
Ok(faces)
}
/// Get landmarks for a face
pub fn get_landmarks(&self, face_id: i64) -> Result<Option<LandmarkRecord>> {
let mut stmt = self
.conn
.prepare(
r#"
SELECT id, face_id, left_eye_x, left_eye_y, right_eye_x, right_eye_y,
nose_x, nose_y, left_mouth_x, left_mouth_y, right_mouth_x, right_mouth_y
FROM landmarks WHERE face_id = ?1
"#,
)
.change_context(Error)?;
let result = stmt
.query_row(params![face_id], |row| {
Ok(LandmarkRecord {
id: row.get(0)?,
face_id: row.get(1)?,
left_eye_x: row.get(2)?,
left_eye_y: row.get(3)?,
right_eye_x: row.get(4)?,
right_eye_y: row.get(5)?,
nose_x: row.get(6)?,
nose_y: row.get(7)?,
left_mouth_x: row.get(8)?,
left_mouth_y: row.get(9)?,
right_mouth_x: row.get(10)?,
right_mouth_y: row.get(11)?,
})
})
.optional()
.change_context(Error)?;
Ok(result)
}
/// Get embeddings for a face
pub fn get_embeddings(&self, face_id: i64) -> Result<Vec<EmbeddingRecord>> {
let mut stmt = self
.conn
.prepare(
"SELECT id, face_id, embedding, model_name, created_at FROM embeddings WHERE face_id = ?1",
)
.change_context(Error)?;
let embedding_iter = stmt
.query_map(params![face_id], |row| {
let embedding_bytes: Vec<u8> = row.get(2)?;
let embedding: ndarray::Array1<f32> = {
// Surface deserialization failures as rusqlite conversion errors
// instead of panicking on a corrupt blob
let to_sql_err = |e: ndarray_safetensors::SafeTensorError| {
rusqlite::Error::FromSqlConversionFailure(
2,
rusqlite::types::Type::Blob,
Box::new(e),
)
};
let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes)
.map_err(to_sql_err)?;
sf.tensor::<f32, ndarray::Ix1>("embedding")
.map_err(to_sql_err)?
.to_owned()
};
Ok(EmbeddingRecord {
id: row.get(0)?,
face_id: row.get(1)?,
embedding,
model_name: row.get(3)?,
created_at: row.get(4)?,
})
})
.change_context(Error)?;
let mut embeddings = Vec::new();
for embedding in embedding_iter {
embeddings.push(embedding.change_context(Error)?);
}
Ok(embeddings)
}
pub fn get_image_for_face(&self, face_id: i64) -> Result<Option<ImageRecord>> {
let mut stmt = self
.conn
.prepare(
r#"
SELECT images.id, images.file_path, images.width, images.height, images.created_at
FROM images
JOIN faces ON faces.image_id = images.id
WHERE faces.id = ?1
"#,
)
.change_context(Error)?;
let result = stmt
.query_row(params![face_id], |row| {
Ok(ImageRecord {
id: row.get(0)?,
file_path: row.get(1)?,
width: row.get(2)?,
height: row.get(3)?,
created_at: row.get(4)?,
})
})
.optional()
.change_context(Error)?;
Ok(result)
}
/// Get database statistics
pub fn get_stats(&self) -> Result<(usize, usize, usize, usize)> {
let images: usize = self
.conn
.query_row("SELECT COUNT(*) FROM images", [], |row| row.get(0))
.change_context(Error)?;
let faces: usize = self
.conn
.query_row("SELECT COUNT(*) FROM faces", [], |row| row.get(0))
.change_context(Error)?;
let landmarks: usize = self
.conn
.query_row("SELECT COUNT(*) FROM landmarks", [], |row| row.get(0))
.change_context(Error)?;
let embeddings: usize = self
.conn
.query_row("SELECT COUNT(*) FROM embeddings", [], |row| row.get(0))
.change_context(Error)?;
Ok((images, faces, landmarks, embeddings))
}
/// Find similar faces based on cosine similarity of embeddings
/// Return ids and similarity scores of similar faces
pub fn find_similar_faces(
&self,
embedding: &ndarray::Array1<f32>,
threshold: f32,
limit: usize,
) -> Result<Vec<(i64, f32)>> {
// Serialize the query embedding to bytes
let embedding_bytes =
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding.view())])
.change_context(Error)?
.serialize()
.change_context(Error)?;
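// `cosine_similarity` here is the scalar SQL function provided by the
// sqlite3-safetensor-cosine extension loaded in `FaceDatabase::new`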
let mut stmt = self
.conn
.prepare(
r#" SELECT face_id, cosine_similarity(?1, embedding) as similarity
FROM embeddings
WHERE cosine_similarity(?1, embedding) >= ?2
ORDER BY similarity DESC
LIMIT ?3"#,
)
.change_context(Error)?;
let result = stmt
.query_map(params![embedding_bytes, threshold, limit], |row| {
Ok((row.get::<_, i64>(0)?, row.get::<_, f32>(1)?))
})
.change_context(Error)?
.map(|r| r.change_context(Error))
.collect::<Result<Vec<_>>>()?;
Ok(result)
}
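/// Debug helper: print the similarity of `embedding` against every stored
/// embedding row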
pub fn query_similarity(&self, embedding: &ndarray::Array1<f32>) {
let embedding_bytes =
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding.view())])
.change_context(Error)
.unwrap()
.serialize()
.change_context(Error)
.unwrap();
let mut stmt = self
.conn
.prepare(
r#"
SELECT face_id,
cosine_similarity(?1, embedding)
FROM embeddings
"#,
)
.change_context(Error)
.unwrap();
let result_iter = stmt
.query_map(params![embedding_bytes], |row| {
Ok((row.get::<_, i64>(0)?, row.get::<_, f32>(1)?))
})
.change_context(Error)
.unwrap();
for result in result_iter {
println!("{:?}", result);
}
}
}
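/// In-process alternative to loading the compiled extension: registers the
/// same `cosine_similarity` scalar function directly on the connection via
/// rusqlite. Currently unused by `FaceDatabase::new`, which loads the cdylib
/// instead.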
fn add_sqlite_cosine_similarity(db: &Connection) -> Result<()> {
use rusqlite::functions::*;
db.create_scalar_function(
"cosine_similarity",
2,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
move |ctx| {
if ctx.len() != 2 {
return Err(rusqlite::Error::UserFunctionError(
"cosine_similarity requires exactly 2 arguments".into(),
));
}
let array_1 = ctx.get_raw(0).as_blob()?;
let array_2 = ctx.get_raw(1).as_blob()?;
let array_1_st = ndarray_safetensors::SafeArraysView::from_bytes(array_1)
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
let array_2_st = ndarray_safetensors::SafeArraysView::from_bytes(array_2)
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
let array_view_1 = array_1_st
.tensor_by_index::<f32, ndarray::Ix1>(0)
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
let array_view_2 = array_2_st
.tensor_by_index::<f32, ndarray::Ix1>(0)
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
let similarity = array_view_1
.cosine_similarity(array_view_2)
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
Ok(similarity)
},
)
.change_context(Error)
}
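Once registered, the UDF is callable from any SQL run on that connection. A hedged sketch, where `blob_a` and `blob_b` stand for safetensors-serialized 1-D f32 blobs built the same way as in `find_similar_faces`:

let conn = Connection::open_in_memory().change_context(Error)?;
add_sqlite_cosine_similarity(&conn)?;
// blob_a / blob_b: hypothetical serialized embeddings
let sim: f32 = conn
    .query_row("SELECT cosine_similarity(?1, ?2)", params![blob_a, blob_b], |row| {
        row.get(0)
    })
    .change_context(Error)?;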
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_database_creation() -> Result<()> {
let db = FaceDatabase::in_memory()?;
let (images, faces, landmarks, embeddings) = db.get_stats()?;
assert_eq!(images, 0);
assert_eq!(faces, 0);
assert_eq!(landmarks, 0);
assert_eq!(embeddings, 0);
Ok(())
}
#[test]
fn test_store_and_retrieve_image() -> Result<()> {
let db = FaceDatabase::in_memory()?;
let image_id = db.store_image("/path/to/image.jpg", 800, 600)?;
let image = db.get_image(image_id)?.unwrap();
assert_eq!(image.file_path, "/path/to/image.jpg");
assert_eq!(image.width, 800);
assert_eq!(image.height, 600);
Ok(())
}
}


@@ -1,2 +1,8 @@
pub mod retinaface;
pub mod yolo;
// Re-export common types and traits
pub use retinaface::{
FaceDetectionConfig, FaceDetectionModelOutput, FaceDetectionOutput,
FaceDetectionProcessedOutput, FaceDetector, FaceLandmarks,
};


@@ -1,67 +1,88 @@
pub mod mnn;
pub mod ort;
use crate::errors::*;
use bounding_box::{Aabb2, nms::nms};
use error_stack::ResultExt;
use mnn_bridge::ndarray::*;
use nalgebra::{Point2, Vector2};
use ndarray_resize::NdFir;
use std::path::Path;
/// Configuration for face detection postprocessing
#[derive(Debug, Clone, PartialEq)]
pub struct FaceDetectionConfig {
    /// Minimum confidence to keep a detection
    pub threshold: f32,
    /// NMS threshold for suppressing overlapping boxes
    pub nms_threshold: f32,
    /// Variances for bounding box decoding
    pub variances: [f32; 2],
    /// The step size (stride) for each feature map
    pub steps: Vec<usize>,
    /// The minimum anchor sizes for each feature map
    pub min_sizes: Vec<Vec<usize>>,
    /// Whether to clip bounding boxes to the image dimensions
    pub clamp: bool,
    /// Input image width (used for anchor generation)
    pub input_width: usize,
    /// Input image height (used for anchor generation)
    pub input_height: usize,
}
impl FaceDetectionConfig {
    pub fn with_threshold(mut self, threshold: f32) -> Self {
        self.threshold = threshold;
        self
    }
    pub fn with_nms_threshold(mut self, nms_threshold: f32) -> Self {
        self.nms_threshold = nms_threshold;
        self
    }
    pub fn with_variances(mut self, variances: [f32; 2]) -> Self {
        self.variances = variances;
        self
    }
    pub fn with_steps(mut self, steps: Vec<usize>) -> Self {
        self.steps = steps;
        self
    }
    pub fn with_min_sizes(mut self, min_sizes: Vec<Vec<usize>>) -> Self {
        self.min_sizes = min_sizes;
        self
    }
    pub fn with_clip(mut self, clip: bool) -> Self {
        self.clamp = clip;
        self
    }
    pub fn with_input_width(mut self, input_width: usize) -> Self {
        self.input_width = input_width;
        self
    }
    pub fn with_input_height(mut self, input_height: usize) -> Self {
        self.input_height = input_height;
        self
    }
}
impl Default for FaceDetectionConfig {
    fn default() -> Self {
        Self {
            threshold: 0.5,
            nms_threshold: 0.4,
            variances: [0.1, 0.2],
            steps: vec![8, 16, 32],
            min_sizes: vec![vec![16, 32], vec![64, 128], vec![256, 512]],
            clamp: true,
            input_width: 1024,
            input_height: 1024,
        }
    }
}
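A hedged configuration sketch built from the setters above; the 640x640 input and tightened thresholds are illustrative values, not taken from the diff:

let config = FaceDetectionConfig::default()
    .with_threshold(0.7)
    .with_nms_threshold(0.3)
    .with_input_width(640)
    .with_input_height(640);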
/// Represents the 5 facial landmarks detected by RetinaFace
#[derive(Debug, Copy, Clone, PartialEq)]
@@ -73,6 +94,13 @@ pub struct FaceLandmarks {
pub right_mouth: Point2<f32>,
}
#[derive(Debug, Clone, PartialEq)]
pub struct FaceDetectionModelOutput {
pub bbox: ndarray::Array3<f32>,
pub confidence: ndarray::Array3<f32>,
pub landmark: ndarray::Array3<f32>,
}
#[derive(Debug, Clone, PartialEq)]
pub struct FaceDetectionProcessedOutput {
pub bbox: Vec<Aabb2<f32>>,
@@ -87,85 +115,134 @@ pub struct FaceDetectionOutput {
pub landmark: Vec<FaceLandmarks>,
}
/// Raw model outputs that can be converted to FaceDetectionModelOutput
pub trait IntoModelOutput {
fn into_model_output(self) -> Result<FaceDetectionModelOutput>;
}
/// Generate anchors for RetinaFace model
pub fn generate_anchors(config: &FaceDetectionConfig) -> ndarray::Array2<f32> {
let mut anchors = Vec::new();
let feature_maps: Vec<(usize, usize)> = config
.steps
.iter()
.map(|&step| {
(
(config.input_height as f32 / step as f32).ceil() as usize,
(config.input_width as f32 / step as f32).ceil() as usize,
)
})
.collect();
for (k, f) in feature_maps.iter().enumerate() {
let min_sizes = &config.min_sizes[k];
for i in 0..f.0 {
for j in 0..f.1 {
for &min_size in min_sizes {
let s_kx = min_size as f32 / config.input_width as f32;
let s_ky = min_size as f32 / config.input_height as f32;
let dense_cx =
(j as f32 + 0.5) * config.steps[k] as f32 / config.input_width as f32;
let dense_cy =
(i as f32 + 0.5) * config.steps[k] as f32 / config.input_height as f32;
anchors.push([
dense_cx - s_kx / 2.,
dense_cy - s_ky / 2.,
dense_cx + s_kx / 2.,
dense_cy + s_ky / 2.,
]);
}
}
}
}
ndarray::Array2::from_shape_vec((anchors.len(), 4), anchors.into_iter().flatten().collect())
.unwrap()
}
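With the default config (1024x1024 input, strides 8/16/32, two min sizes per level) the feature-map sides are 128, 64, and 32, so the prior count works out to 128^2*2 + 64^2*2 + 32^2*2 = 43,008. A quick sanity check:

let anchors = generate_anchors(&FaceDetectionConfig::default());
assert_eq!(anchors.shape(), &[43_008, 4]);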
impl FaceDetectionModelOutput {
    pub fn postprocess(self, config: &FaceDetectionConfig) -> Result<FaceDetectionProcessedOutput> {
        use ndarray::s;
        let priors = generate_anchors(config);
        let scores = self.confidence.slice(s![0, .., 1]);
        let boxes = self.bbox.slice(s![0, .., ..]);
        let landmarks_raw = self.landmark.slice(s![0, .., ..]);
        dbg!(priors.shape());
        let (decoded_boxes, decoded_landmarks, confidences) = (0..priors.shape()[0])
            .filter(|&i| scores[i] > config.threshold)
            .map(|i| {
                let prior = priors.row(i);
                let loc = boxes.row(i);
                let landm = landmarks_raw.row(i);
                // Decode bounding box
                let prior_cx = (prior[0] + prior[2]) / 2.0;
                let prior_cy = (prior[1] + prior[3]) / 2.0;
                let prior_w = prior[2] - prior[0];
                let prior_h = prior[3] - prior[1];
                let var = config.variances;
                let cx = prior_cx + loc[0] * var[0] * prior_w;
                let cy = prior_cy + loc[1] * var[0] * prior_h;
                let w = prior_w * (loc[2] * var[1]).exp();
                let h = prior_h * (loc[3] * var[1]).exp();
                let xmin = cx - w / 2.0;
                let ymin = cy - h / 2.0;
                let xmax = cx + w / 2.0;
                let ymax = cy + h / 2.0;
                let mut bbox =
                    Aabb2::from_min_max_vertices(Point2::new(xmin, ymin), Point2::new(xmax, ymax));
                if config.clamp {
                    bbox = bbox.component_clamp(0.0, 1.0);
                }
                // Decode landmarks
                let points: [Point2<f32>; 5] = (0..5)
                    .map(|j| {
                        Point2::new(
                            prior_cx + landm[j * 2] * var[0] * prior_w,
                            prior_cy + landm[j * 2 + 1] * var[0] * prior_h,
                        )
                    })
                    .collect::<Vec<_>>()
                    .try_into()
                    .unwrap();
                let landmarks = FaceLandmarks {
                    left_eye: points[0],
                    right_eye: points[1],
                    nose: points[2],
                    left_mouth: points[3],
                    right_mouth: points[4],
                };
                (bbox, landmarks, scores[i])
            })
            .fold(
                (Vec::new(), Vec::new(), Vec::new()),
                |(mut boxes, mut landmarks, mut confs), (bbox, landmark, conf)| {
                    boxes.push(bbox);
                    landmarks.push(landmark);
                    confs.push(conf);
                    (boxes, landmarks, confs)
                },
            );
        Ok(FaceDetectionProcessedOutput {
            bbox: decoded_boxes,
            confidence: confidences,
            landmarks: decoded_landmarks,
        })
    }
}
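As a worked instance of the decode above: with variances [0.1, 0.2], a prior of width and height 0.1 centered at (0.5, 0.5), and loc = [1.0, 0.0, 0.0, 0.0], the center moves to 0.5 + 1.0 * 0.1 * 0.1 = 0.51 while the size is unchanged (exp(0) = 1), so the decoded box spans 0.46..0.56 horizontally in normalized coordinates.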
impl FaceDetectionModelOutput {
pub fn print(&self, limit: usize) {
tracing::info!("Detected {} faces", self.bbox.shape()[1]);
@@ -189,49 +266,16 @@ impl FaceDetectionModelOutput {
}
}
/// Apply Non-Maximum Suppression and convert to final output format
pub fn apply_nms_and_finalize(
    processed: FaceDetectionProcessedOutput,
    config: &FaceDetectionConfig,
    image_size: (usize, usize), // (width, height)
) -> Result<FaceDetectionOutput> {
    use itertools::Itertools;
    // Denormalize boxes and landmarks back to pixel coordinates
    let factor = Vector2::new(image_size.0 as f32, image_size.1 as f32);
    let (boxes, scores, landmarks): (Vec<_>, Vec<_>, Vec<_>) = processed
        .bbox
        .iter()
@@ -242,7 +286,8 @@ impl FaceDetection {
.map(|((b, s), l)| (b, s, l))
.multiunzip();
    let keep_indices =
        nms(&boxes, &scores, config.threshold, config.nms_threshold).change_context(Error)?;
let bboxes = boxes
.into_iter()
@@ -270,79 +315,27 @@ impl FaceDetection {
})
}
/// Common trait for face detection backends
pub trait FaceDetector {
    /// Run inference on the model and return raw outputs
    fn run_model(&mut self, image: ndarray::ArrayView3<u8>) -> Result<FaceDetectionModelOutput>;
    /// Detect faces with full pipeline including postprocessing
    fn detect_faces(
        &mut self,
        image: ndarray::ArrayView3<u8>,
        config: &FaceDetectionConfig,
    ) -> Result<FaceDetectionOutput> {
        let (height, width, _channels) = image.dim();
        let output = self
            .run_model(image)
            .change_context(Error)
            .attach_printable("Failed to detect faces")?;
        let processed = output
            .postprocess(config)
            .attach_printable("Failed to postprocess")?;
        apply_nms_and_finalize(processed, config, (width, height))
    }
}
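A hedged end-to-end sketch of the trait, where `model_bytes` and `image` are placeholders for the RetinaFace weights and an HxWx3 u8 image:

let mut detector = mnn::FaceDetection::builder(model_bytes)?.build()?;
let config = FaceDetectionConfig::default();
let faces = detector.detect_faces(image.view(), &config)?;
tracing::info!("kept {} faces after NMS", faces.landmark.len());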


@@ -0,0 +1,146 @@
use crate::errors::*;
use crate::facedet::*;
use error_stack::ResultExt;
use mnn_bridge::ndarray::*;
use ndarray_resize::NdFir;
use std::path::Path;
#[derive(Debug)]
pub struct FaceDetection {
handle: mnn_sync::SessionHandle,
}
pub struct FaceDetectionBuilder {
schedule_config: Option<mnn::ScheduleConfig>,
backend_config: Option<mnn::BackendConfig>,
model: mnn::Interpreter,
}
impl FaceDetectionBuilder {
pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
Ok(Self {
schedule_config: None,
backend_config: None,
model: mnn::Interpreter::from_bytes(model.as_ref())
.map_err(|e| e.into_inner())
.change_context(Error)
.attach_printable("Failed to load model from bytes")?,
})
}
pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
self.schedule_config
.get_or_insert_default()
.set_type(forward_type);
self
}
pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
self.schedule_config = Some(config);
self
}
pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
self.backend_config = Some(config);
self
}
pub fn build(self) -> Result<FaceDetection> {
let model = self.model;
let sc = self.schedule_config.unwrap_or_default();
let handle = mnn_sync::SessionHandle::new(model, sc)
.change_context(Error)
.attach_printable("Failed to create session handle")?;
Ok(FaceDetection { handle })
}
}
impl FaceDetection {
pub fn builder<T: AsRef<[u8]>>(
model: T,
) -> std::result::Result<FaceDetectionBuilder, error_stack::Report<crate::errors::Error>> {
FaceDetectionBuilder::new(model)
}
}
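A sketch of the builder above; whether `Metal` is available depends on how the mnn crate was built, so treat the forward type here as an assumption:

let detector = FaceDetection::builder(model_bytes)?
    .with_forward_type(mnn::ForwardType::Metal)
    .build()?;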
impl FaceDetector for FaceDetection {
fn run_model(&mut self, image: ndarray::ArrayView3<u8>) -> Result<FaceDetectionModelOutput> {
#[rustfmt::skip]
let mut resized = image
.fast_resize(1024, 1024, None)
.change_context(Error)?
.mapv(|f| f as f32);
// Apply mean subtraction: [104, 117, 123]
resized
.axis_iter_mut(ndarray::Axis(2))
.zip([104, 117, 123])
.for_each(|(mut array, pixel)| {
let pixel = pixel as f32;
array.map_inplace(|v| *v -= pixel);
});
let mut resized = resized
.permuted_axes((2, 0, 1))
.insert_axis(ndarray::Axis(0))
.as_standard_layout()
.into_owned();
use ::tap::*;
let output = self
.handle
.run(move |sr| {
let tensor = resized
.as_mnn_tensor_mut()
.attach_printable("Failed to convert ndarray to mnn tensor")
.change_context(mnn::error::ErrorKind::TensorError)?;
tracing::trace!("Image Tensor shape: {:?}", tensor.shape());
let (intptr, session) = sr.both_mut();
tracing::trace!("Copying input tensor to host");
unsafe {
let mut input = intptr.input_unresized::<f32>(session, "input")?;
tracing::trace!("Input shape: {:?}", input.shape());
intptr.resize_tensor_by_nchw::<mnn::View<&mut f32>, _>(
input.view_mut(),
1,
3,
1024,
1024,
);
}
intptr.resize_session(session);
let mut input = intptr.input::<f32>(session, "input")?;
tracing::trace!("Input shape: {:?}", input.shape());
input.copy_from_host_tensor(tensor.view())?;
tracing::info!("Running face detection session");
intptr.run_session(&session)?;
let output_tensor = intptr
.output::<f32>(&session, "bbox")?
.create_host_tensor_from_device(true)
.as_ndarray()
.to_owned();
tracing::trace!("Output Bbox: \t\t{:?}", output_tensor.shape());
let output_confidence = intptr
.output::<f32>(&session, "confidence")?
.create_host_tensor_from_device(true)
.as_ndarray::<ndarray::Ix3>()
.to_owned();
tracing::trace!("Output Confidence: \t{:?}", output_confidence.shape());
let output_landmark = intptr
.output::<f32>(&session, "landmark")?
.create_host_tensor_from_device(true)
.as_ndarray::<ndarray::Ix3>()
.to_owned();
tracing::trace!("Output Landmark: \t{:?}", output_landmark.shape());
Ok(FaceDetectionModelOutput {
bbox: output_tensor,
confidence: output_confidence,
landmark: output_landmark,
})
})
.map_err(|e| e.into_inner())
.change_context(Error)?;
Ok(output)
}
}


@@ -0,0 +1,256 @@
use crate::errors::*;
use crate::facedet::*;
use crate::ort_ep::*;
use error_stack::ResultExt;
use ndarray_resize::NdFir;
use ort::{
execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch},
session::{Session, builder::GraphOptimizationLevel},
value::Tensor,
};
use std::path::Path;
#[derive(Debug)]
pub struct FaceDetection {
session: Session,
}
pub struct FaceDetectionBuilder {
model_data: Vec<u8>,
execution_providers: Option<Vec<ExecutionProviderDispatch>>,
intra_threads: Option<usize>,
inter_threads: Option<usize>,
}
impl FaceDetectionBuilder {
pub fn new(model: impl AsRef<[u8]>) -> crate::errors::Result<Self> {
Ok(Self {
model_data: model.as_ref().to_vec(),
execution_providers: None,
intra_threads: None,
inter_threads: None,
})
}
pub fn with_execution_providers(mut self, providers: impl AsRef<[ExecutionProvider]>) -> Self {
let execution_providers: Vec<ExecutionProviderDispatch> = providers
.as_ref()
.iter()
.filter_map(|provider| provider.to_dispatch())
.collect();
if !execution_providers.is_empty() {
self.execution_providers = Some(execution_providers);
} else {
tracing::warn!("No valid execution providers found, falling back to CPU");
self.execution_providers = Some(vec![CPUExecutionProvider::default().build()]);
}
self
}
pub fn with_intra_threads(mut self, threads: usize) -> Self {
self.intra_threads = Some(threads);
self
}
pub fn with_inter_threads(mut self, threads: usize) -> Self {
self.inter_threads = Some(threads);
self
}
pub fn build(self) -> crate::errors::Result<FaceDetection> {
let mut session_builder = Session::builder()
.change_context(Error)
.attach_printable("Failed to create session builder")?;
// Set execution providers
if let Some(providers) = self.execution_providers {
session_builder = session_builder
.with_execution_providers(providers)
.change_context(Error)
.attach_printable("Failed to set execution providers")?;
} else {
// Default to CPU
session_builder = session_builder
.with_execution_providers([CPUExecutionProvider::default().build()])
.change_context(Error)
.attach_printable("Failed to set default CPU execution provider")?;
}
// Set threading options
if let Some(threads) = self.intra_threads {
session_builder = session_builder
.with_intra_threads(threads)
.change_context(Error)
.attach_printable("Failed to set intra threads")?;
}
if let Some(threads) = self.inter_threads {
session_builder = session_builder
.with_inter_threads(threads)
.change_context(Error)
.attach_printable("Failed to set inter threads")?;
}
// Set optimization level
session_builder = session_builder
.with_optimization_level(GraphOptimizationLevel::Level3)
.change_context(Error)
.attach_printable("Failed to set optimization level")?;
// Create session from model bytes
let session = session_builder
.commit_from_memory(&self.model_data)
.change_context(Error)
.attach_printable("Failed to create ORT session from model bytes")?;
tracing::info!("Successfully created ORT RetinaFace session");
Ok(FaceDetection { session })
}
}
impl FaceDetection {
pub fn builder<T: AsRef<[u8]>>(
model: T,
) -> std::result::Result<FaceDetectionBuilder, error_stack::Report<crate::errors::Error>> {
FaceDetectionBuilder::new(model)
}
pub fn new(path: impl AsRef<Path>) -> crate::errors::Result<Self> {
let model = std::fs::read(path)
.change_context(Error)
.attach_printable("Failed to read model file")?;
Self::new_from_bytes(&model)
}
pub fn new_from_bytes(model: &[u8]) -> crate::errors::Result<Self> {
tracing::info!("Loading ORT RetinaFace model from bytes");
Self::builder(model)?.build()
}
}
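A hedged construction sketch for the ORT backend; the thread counts are illustrative and `model_bytes` is a placeholder for the ONNX weights:

let detector = FaceDetection::builder(model_bytes)?
    .with_intra_threads(4)
    .with_inter_threads(1)
    .build()?;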
impl FaceDetector for FaceDetection {
fn run_model(
&mut self,
image: ndarray::ArrayView3<u8>,
) -> crate::errors::Result<FaceDetectionModelOutput> {
// Resize image to 1024x1024
let mut resized = image
.fast_resize(1024, 1024, None)
.change_context(Error)
.attach_printable("Failed to resize image")?
.mapv(|f| f as f32);
// Apply mean subtraction: [104, 117, 123] for BGR format
resized
.axis_iter_mut(ndarray::Axis(2))
.zip([104.0, 117.0, 123.0])
.for_each(|(mut array, mean)| {
array.map_inplace(|v| *v -= mean);
});
// Convert from HWC to NCHW format (add batch dimension and transpose)
let input_tensor = resized
.permuted_axes((2, 0, 1))
.insert_axis(ndarray::Axis(0))
.as_standard_layout()
.into_owned();
tracing::trace!("Input tensor shape: {:?}", input_tensor.shape());
// Create ORT input tensor
let input_value = Tensor::from_array(input_tensor)
.change_context(Error)
.attach_printable("Failed to create input tensor")?;
// Run inference
tracing::debug!("Running ORT RetinaFace inference");
let outputs = self
.session
.run(ort::inputs!["input" => input_value])
.change_context(Error)
.attach_printable("Failed to run inference")?;
// Extract outputs by name
let bbox_output = outputs
.get("bbox")
.ok_or(Error)
.attach_printable("Missing bbox output from model")?
.try_extract_tensor::<f32>()
.change_context(Error)
.attach_printable("Failed to extract bbox tensor")?;
let confidence_output = outputs
.get("confidence")
.ok_or(Error)
.attach_printable("Missing confidence output from model")?
.try_extract_tensor::<f32>()
.change_context(Error)
.attach_printable("Failed to extract confidence tensor")?;
let landmark_output = outputs
.get("landmark")
.ok_or(Error)
.attach_printable("Missing landmark output from model")?
.try_extract_tensor::<f32>()
.change_context(Error)
.attach_printable("Failed to extract landmark tensor")?;
// Get tensor shapes and data
let (bbox_shape, bbox_data) = bbox_output;
let (confidence_shape, confidence_data) = confidence_output;
let (landmark_shape, landmark_data) = landmark_output;
tracing::trace!(
"Output shapes - bbox: {:?}, confidence: {:?}, landmark: {:?}",
bbox_shape,
confidence_shape,
landmark_shape
);
// Convert to ndarray format
let bbox_dims = bbox_shape.as_ref();
let confidence_dims = confidence_shape.as_ref();
let landmark_dims = landmark_shape.as_ref();
let bbox_array = ndarray::Array3::from_shape_vec(
(
bbox_dims[0] as usize,
bbox_dims[1] as usize,
bbox_dims[2] as usize,
),
bbox_data.to_vec(),
)
.change_context(Error)
.attach_printable("Failed to create bbox ndarray")?;
let confidence_array = ndarray::Array3::from_shape_vec(
(
confidence_dims[0] as usize,
confidence_dims[1] as usize,
confidence_dims[2] as usize,
),
confidence_data.to_vec(),
)
.change_context(Error)
.attach_printable("Failed to create confidence ndarray")?;
let landmark_array = ndarray::Array3::from_shape_vec(
(
landmark_dims[0] as usize,
landmark_dims[1] as usize,
landmark_dims[2] as usize,
),
landmark_data.to_vec(),
)
.change_context(Error)
.attach_printable("Failed to create landmark ndarray")?;
Ok(FaceDetectionModelOutput {
bbox: bbox_array,
confidence: confidence_array,
landmark: landmark_array,
})
}
}

src/faceembed.rs

@@ -0,0 +1,35 @@
pub mod facenet;
// Re-export common types and traits
pub use facenet::FaceNetEmbedder;
pub use facenet::{FaceEmbedding, FaceEmbeddingConfig, IntoEmbeddings};
// Convenience type aliases for different backends
pub use facenet::mnn::EmbeddingGenerator as MnnEmbeddingGenerator;
pub use facenet::ort::EmbeddingGenerator as OrtEmbeddingGenerator;
use crate::errors::*;
use ndarray::{Array2, ArrayView4};
pub mod preprocessing {
use ndarray::*;
pub fn preprocess(faces: ArrayView4<u8>) -> Array4<f32> {
        let mut owned = faces.as_standard_layout().mapv(|v| v as f32);
owned.axis_iter_mut(Axis(0)).for_each(|mut image| {
let mean = image.mean().unwrap_or(0.0);
let std = image.std(0.0);
if std > 0.0 {
image.mapv_inplace(|x| (x - mean) / std);
} else {
image.mapv_inplace(|x| (x - 127.5) / 128.0)
}
});
owned
}
}
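A quick numeric check of the standardization above: a constant image has zero std, so the fixed (x - 127.5) / 128.0 fallback applies.

let faces = ndarray::Array4::<u8>::from_elem((1, 2, 2, 3), 100);
let out = preprocessing::preprocess(faces.view());
assert!((out[[0, 0, 0, 0]] - (100.0 - 127.5) / 128.0).abs() < 1e-6);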
/// Common trait for face embedding backends - maintained for backward compatibility
pub trait FaceEmbedder {
/// Generate embeddings for a batch of face images
fn run_models(&mut self, faces: ArrayView4<u8>) -> Result<Array2<f32>>;
}

src/faceembed/facenet.rs

@@ -0,0 +1,209 @@
pub mod mnn;
pub mod ort;
use crate::errors::*;
use error_stack::ResultExt;
use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
use ndarray_math::{CosineSimilarity, EuclideanDistance};
/// Configuration for face embedding processing
#[derive(Debug, Clone, PartialEq)]
pub struct FaceEmbeddingConfig {
/// Input image width expected by the model
pub input_width: usize,
/// Input image height expected by the model
pub input_height: usize,
/// Whether to normalize embeddings to unit vectors
pub normalize: bool,
}
impl FaceEmbeddingConfig {
pub fn with_input_size(mut self, width: usize, height: usize) -> Self {
self.input_width = width;
self.input_height = height;
self
}
pub fn with_normalization(mut self, normalize: bool) -> Self {
self.normalize = normalize;
self
}
}
impl Default for FaceEmbeddingConfig {
fn default() -> Self {
Self {
input_width: 320,
input_height: 320,
normalize: false,
}
}
}
/// Represents a face embedding vector
#[derive(Debug, Clone, PartialEq)]
pub struct FaceEmbedding {
/// The embedding vector
pub vector: Array1<f32>,
/// Optional confidence score for the embedding quality
pub confidence: Option<f32>,
}
impl FaceEmbedding {
pub fn new(vector: Array1<f32>) -> Self {
Self {
vector,
confidence: None,
}
}
pub fn with_confidence(mut self, confidence: f32) -> Self {
self.confidence = Some(confidence);
self
}
/// Calculate cosine similarity with another embedding
pub fn cosine_similarity(&self, other: &FaceEmbedding) -> f32 {
self.vector.cosine_similarity(&other.vector).unwrap_or(0.0)
}
/// Calculate Euclidean distance with another embedding
pub fn euclidean_distance(&self, other: &FaceEmbedding) -> f32 {
self.vector
.euclidean_distance(other.vector.view())
.unwrap_or(f32::INFINITY)
}
/// Normalize the embedding vector to unit length
pub fn normalize(&mut self) {
let norm = self.vector.mapv(|x| x * x).sum().sqrt();
if norm > 0.0 {
self.vector.mapv_inplace(|x| x / norm);
}
}
/// Get the dimensionality of the embedding
pub fn dimension(&self) -> usize {
self.vector.len()
}
}
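A small check of the similarity helpers, assuming nothing beyond the API above: orthogonal vectors score zero.

let a = FaceEmbedding::new(ndarray::arr1(&[1.0f32, 0.0, 0.0]));
let b = FaceEmbedding::new(ndarray::arr1(&[0.0f32, 1.0, 0.0]));
assert!(a.cosine_similarity(&b).abs() < 1e-6);
assert_eq!(a.dimension(), 3);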
/// Raw model outputs that can be converted to embeddings
pub trait IntoEmbeddings {
fn into_embeddings(self, config: &FaceEmbeddingConfig) -> Result<Vec<FaceEmbedding>>;
}
impl IntoEmbeddings for Array2<f32> {
fn into_embeddings(self, config: &FaceEmbeddingConfig) -> Result<Vec<FaceEmbedding>> {
let mut embeddings = Vec::new();
for row in self.rows() {
let mut vector = row.to_owned();
if config.normalize {
let norm = vector.mapv(|x| x * x).sum().sqrt();
if norm > 0.0 {
vector.mapv_inplace(|x| x / norm);
}
}
embeddings.push(FaceEmbedding::new(vector));
}
Ok(embeddings)
}
}
/// Common trait for face embedding backends
pub trait FaceNetEmbedder {
/// Generate embeddings for a batch of face images
fn run_model(&mut self, faces: ArrayView4<u8>) -> Result<Array2<f32>>;
/// Generate embeddings with full pipeline including postprocessing
fn generate_embeddings(
&mut self,
faces: ArrayView4<u8>,
config: FaceEmbeddingConfig,
) -> Result<Vec<FaceEmbedding>> {
let raw_output = self
.run_model(faces)
.change_context(Error)
.attach_printable("Failed to generate embeddings")?;
raw_output
.into_embeddings(&config)
.attach_printable("Failed to process embeddings")
}
/// Generate a single embedding from a single face image
fn generate_embedding(
&mut self,
face: ArrayView3<u8>,
config: FaceEmbeddingConfig,
) -> Result<FaceEmbedding> {
// Add batch dimension
let face_batch = face.insert_axis(ndarray::Axis(0));
let embeddings = self.generate_embeddings(face_batch.view(), config)?;
embeddings
.into_iter()
.next()
.ok_or(Error)
.attach_printable("No embedding generated for input face")
}
}
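A hedged sketch of the default pipeline, where `embedder` is any backend implementing the trait and `face` is an HxWx3 u8 crop:

let config = FaceEmbeddingConfig::default().with_normalization(true);
let embedding = embedder.generate_embedding(face.view(), config)?;
println!("embedding dim: {}", embedding.dimension());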
/// Utility functions for embedding processing
pub mod utils {
use super::*;
/// Compute pairwise cosine similarities between two sets of embeddings
pub fn pairwise_cosine_similarities(
embeddings1: &[FaceEmbedding],
embeddings2: &[FaceEmbedding],
) -> Array2<f32> {
let n1 = embeddings1.len();
let n2 = embeddings2.len();
let mut similarities = Array2::zeros((n1, n2));
for (i, emb1) in embeddings1.iter().enumerate() {
for (j, emb2) in embeddings2.iter().enumerate() {
similarities[(i, j)] = emb1.cosine_similarity(emb2);
}
}
similarities
}
/// Find the best matching embedding from a gallery for each query
pub fn find_best_matches(
queries: &[FaceEmbedding],
gallery: &[FaceEmbedding],
) -> Vec<(usize, f32)> {
let similarities = pairwise_cosine_similarities(queries, gallery);
let mut best_matches = Vec::new();
for i in 0..queries.len() {
let row = similarities.row(i);
let (best_idx, best_score) = row
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.unwrap();
best_matches.push((best_idx, *best_score));
}
best_matches
}
/// Filter embeddings by minimum quality threshold
pub fn filter_by_confidence(
embeddings: Vec<FaceEmbedding>,
min_confidence: f32,
) -> Vec<FaceEmbedding> {
embeddings
.into_iter()
.filter(|emb| emb.confidence.map_or(true, |conf| conf >= min_confidence))
.collect()
}
}


@@ -0,0 +1,127 @@
use crate::errors::*;
use crate::faceembed::facenet::FaceNetEmbedder;
use mnn_bridge::ndarray::*;
use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
use std::path::Path;
#[derive(Debug)]
pub struct EmbeddingGenerator {
handle: mnn_sync::SessionHandle,
}
pub struct EmbeddingGeneratorBuilder {
schedule_config: Option<mnn::ScheduleConfig>,
backend_config: Option<mnn::BackendConfig>,
model: mnn::Interpreter,
}
impl EmbeddingGeneratorBuilder {
pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
Ok(Self {
schedule_config: None,
backend_config: None,
model: mnn::Interpreter::from_bytes(model.as_ref())
.map_err(|e| e.into_inner())
.change_context(Error)
.attach_printable("Failed to load model from bytes")?,
})
}
pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
self.schedule_config
.get_or_insert_default()
.set_type(forward_type);
self
}
pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
self.schedule_config = Some(config);
self
}
pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
self.backend_config = Some(config);
self
}
pub fn build(self) -> Result<EmbeddingGenerator> {
let model = self.model;
let sc = self.schedule_config.unwrap_or_default();
let handle = mnn_sync::SessionHandle::new(model, sc)
.change_context(Error)
.attach_printable("Failed to create session handle")?;
Ok(EmbeddingGenerator { handle })
}
}
impl EmbeddingGenerator {
const INPUT_NAME: &'static str = "serving_default_input_6:0";
const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0";
pub fn builder<T: AsRef<[u8]>>(
model: T,
) -> std::result::Result<EmbeddingGeneratorBuilder, Report<Error>> {
EmbeddingGeneratorBuilder::new(model)
}
pub fn run_models(&self, face: ArrayView4<u8>) -> Result<Array2<f32>> {
let tensor = crate::faceembed::preprocessing::preprocess(face);
let shape: [usize; 4] = tensor.dim().into();
let shape = shape.map(|f| f as i32);
let output = self
.handle
.run(move |sr| {
let tensor = tensor
.as_mnn_tensor()
.attach_printable("Failed to convert ndarray to mnn tensor")
.change_context(mnn::ErrorKind::TensorError)?;
tracing::trace!("Image Tensor shape: {:?}", tensor.shape());
let (intptr, session) = sr.both_mut();
tracing::trace!("Copying input tensor to host");
let needs_resize = unsafe {
let mut input = intptr.input_unresized::<f32>(session, Self::INPUT_NAME)?;
tracing::trace!("Input shape: {:?}", input.shape());
if *input.shape() != shape {
tracing::trace!("Resizing input tensor to shape: {:?}", shape);
// input.resize(shape);
intptr.resize_tensor(input.view_mut(), shape);
true
} else {
false
}
};
if needs_resize {
tracing::trace!("Resized input tensor to shape: {:?}", shape);
let now = std::time::Instant::now();
intptr.resize_session(session);
tracing::trace!("Session resized in {:?}", now.elapsed());
}
let mut input = intptr.input::<f32>(session, Self::INPUT_NAME)?;
tracing::trace!("Input shape: {:?}", input.shape());
input.copy_from_host_tensor(tensor.view())?;
tracing::info!("Running face detection session");
intptr.run_session(&session)?;
let output_tensor = intptr
.output::<f32>(&session, Self::OUTPUT_NAME)?
.create_host_tensor_from_device(true)
.as_ndarray()
.to_owned();
Ok(output_tensor)
})
.change_context(Error)?;
Ok(output)
}
}
impl FaceNetEmbedder for EmbeddingGenerator {
fn run_model(&mut self, faces: ArrayView4<u8>) -> Result<Array2<f32>> {
self.run_models(faces)
}
}
// Main trait implementation for backward compatibility
impl crate::faceembed::FaceEmbedder for EmbeddingGenerator {
fn run_models(&mut self, faces: ArrayView4<u8>) -> Result<Array2<f32>> {
EmbeddingGenerator::run_models(self, faces)
}
}


@@ -0,0 +1,207 @@
use crate::errors::*;
use crate::faceembed::facenet::FaceNetEmbedder;
use crate::ort_ep::*;
use error_stack::ResultExt;
use ndarray::{Array2, ArrayView4};
use ort::{
execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch},
session::{Session, builder::GraphOptimizationLevel},
value::Tensor,
};
use std::path::Path;
#[derive(Debug)]
pub struct EmbeddingGenerator {
session: Session,
}
pub struct EmbeddingGeneratorBuilder {
model_data: Vec<u8>,
execution_providers: Option<Vec<ExecutionProviderDispatch>>,
intra_threads: Option<usize>,
inter_threads: Option<usize>,
}
impl EmbeddingGeneratorBuilder {
pub fn new(model: impl AsRef<[u8]>) -> crate::errors::Result<Self> {
Ok(Self {
model_data: model.as_ref().to_vec(),
execution_providers: None,
intra_threads: None,
inter_threads: None,
})
}
pub fn with_execution_providers(mut self, providers: impl AsRef<[ExecutionProvider]>) -> Self {
let execution_providers: Vec<ExecutionProviderDispatch> = providers
.as_ref()
.iter()
.filter_map(|provider| provider.to_dispatch())
.collect();
if !execution_providers.is_empty() {
self.execution_providers = Some(execution_providers);
} else {
tracing::warn!("No valid execution providers found, falling back to CPU");
self.execution_providers = Some(vec![CPUExecutionProvider::default().build()]);
}
self
}
pub fn with_intra_threads(mut self, threads: usize) -> Self {
self.intra_threads = Some(threads);
self
}
pub fn with_inter_threads(mut self, threads: usize) -> Self {
self.inter_threads = Some(threads);
self
}
pub fn build(self) -> crate::errors::Result<EmbeddingGenerator> {
let mut session_builder = Session::builder()
.change_context(Error)
.attach_printable("Failed to create session builder")?;
// Set execution providers
if let Some(providers) = self.execution_providers {
session_builder = session_builder
.with_execution_providers(providers)
.change_context(Error)
.attach_printable("Failed to set execution providers")?;
} else {
// Default to CPU
session_builder = session_builder
.with_execution_providers([CPUExecutionProvider::default().build()])
.change_context(Error)
.attach_printable("Failed to set default CPU execution provider")?;
}
// Set threading options
if let Some(threads) = self.intra_threads {
session_builder = session_builder
.with_intra_threads(threads)
.change_context(Error)
.attach_printable("Failed to set intra threads")?;
}
if let Some(threads) = self.inter_threads {
session_builder = session_builder
.with_inter_threads(threads)
.change_context(Error)
.attach_printable("Failed to set inter threads")?;
}
// Set optimization level
session_builder = session_builder
.with_optimization_level(GraphOptimizationLevel::Level3)
.change_context(Error)
.attach_printable("Failed to set optimization level")?;
// Create session from model bytes
let session = session_builder
.commit_from_memory(&self.model_data)
.change_context(Error)
.attach_printable("Failed to create ORT session from model bytes")?;
tracing::info!("Successfully created ORT FaceNet session");
Ok(EmbeddingGenerator { session })
}
}
impl EmbeddingGenerator {
const INPUT_NAME: &'static str = "serving_default_input_6:0";
const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0";
pub fn builder<T: AsRef<[u8]>>(
model: T,
) -> std::result::Result<EmbeddingGeneratorBuilder, error_stack::Report<crate::errors::Error>>
{
EmbeddingGeneratorBuilder::new(model)
}
pub fn new(path: impl AsRef<Path>) -> crate::errors::Result<Self> {
let model = std::fs::read(path)
.change_context(Error)
.attach_printable("Failed to read model file")?;
Self::new_from_bytes(&model)
}
pub fn new_from_bytes(model: impl AsRef<[u8]>) -> crate::errors::Result<Self> {
tracing::info!("Loading ORT face embedding model from bytes");
Self::builder(model)?.build()
}
pub fn run_models(&mut self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
        // Standardize each face (per-image zero mean, unit variance), mirroring:
        //   face_array = np.asarray(face_resized, 'float32')
        //   mean, std = face_array.mean(), face_array.std()
        //   face_normalized = (face_array - mean) / std
        let input_tensor = crate::faceembed::preprocessing::preprocess(faces);
tracing::trace!("Input tensor shape: {:?}", input_tensor.shape());
// Create ORT input tensor
let input_value = Tensor::from_array(input_tensor)
.change_context(Error)
.attach_printable("Failed to create input tensor")?;
// Run inference
tracing::debug!("Running ORT FaceNet inference");
let outputs = self
.session
.run(ort::inputs![Self::INPUT_NAME => input_value])
.change_context(Error)
.attach_printable("Failed to run inference")?;
// Extract output tensor
let output = outputs
.get(Self::OUTPUT_NAME)
.ok_or(Error)
.attach_printable("Missing output from FaceNet model")?
.try_extract_tensor::<f32>()
.change_context(Error)
.attach_printable("Failed to extract output tensor")?;
let (output_shape, output_data) = output;
tracing::trace!("Output shape: {:?}", output_shape);
// Convert to ndarray format
let output_dims = output_shape.as_ref();
// FaceNet typically outputs embeddings as [batch_size, embedding_dim]
let batch_size = output_dims[0] as usize;
let embedding_dim = output_dims[1] as usize;
let output_array =
ndarray::Array2::from_shape_vec((batch_size, embedding_dim), output_data.to_vec())
.change_context(Error)
.attach_printable("Failed to create output ndarray")?;
tracing::trace!(
"Generated embeddings with shape: {:?}",
output_array.shape()
);
Ok(output_array)
}
}
impl FaceNetEmbedder for EmbeddingGenerator {
fn run_model(&mut self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
self.run_models(faces)
}
}
// Main trait implementation for backward compatibility
impl crate::faceembed::FaceEmbedder for EmbeddingGenerator {
fn run_models(&mut self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
        // Delegates to the inherent `run_models`; kept for backward compatibility.
self.run_models(faces)
}
}

src/gui/app.rs

@@ -0,0 +1,891 @@
use iced::{
Alignment, Element, Length, Task, Theme,
widget::{
Space, button, column, container, image, pick_list, progress_bar, row, scrollable, slider,
text,
},
};
use rfd::FileDialog;
use std::path::PathBuf;
use std::sync::Arc;
use crate::gui::bridge::FaceDetectionBridge;
#[derive(Debug, Clone)]
pub enum Message {
// File operations
OpenImageDialog,
ImageSelected(Option<PathBuf>),
OpenSecondImageDialog,
SecondImageSelected(Option<PathBuf>),
SaveOutputDialog,
OutputPathSelected(Option<PathBuf>),
// Detection parameters
ThresholdChanged(f32),
NmsThresholdChanged(f32),
ExecutorChanged(ExecutorType),
// Actions
DetectFaces,
CompareFaces,
ClearResults,
// Results
DetectionComplete(DetectionResult),
ComparisonComplete(ComparisonResult),
// UI state
TabChanged(Tab),
ProgressUpdate(f32),
// Image loading
ImageLoaded(Option<Arc<Vec<u8>>>),
SecondImageLoaded(Option<Arc<Vec<u8>>>),
ProcessedImageUpdated(Option<Vec<u8>>),
}
#[derive(Debug, Clone, PartialEq)]
pub enum Tab {
Detection,
Comparison,
Settings,
}
#[derive(Debug, Clone, PartialEq)]
pub enum ExecutorType {
MnnCpu,
MnnMetal,
MnnCoreML,
OnnxCpu,
}
#[derive(Debug, Clone)]
pub enum DetectionResult {
Success {
image_path: PathBuf,
faces_count: usize,
processed_image: Option<Vec<u8>>,
processing_time: f64,
},
Error(String),
}
#[derive(Debug, Clone)]
pub enum ComparisonResult {
Success {
image1_faces: usize,
image2_faces: usize,
best_similarity: f32,
processing_time: f64,
},
Error(String),
}
#[derive(Debug)]
pub struct FaceDetectorApp {
// Current tab
current_tab: Tab,
// File paths
input_image: Option<PathBuf>,
second_image: Option<PathBuf>,
output_path: Option<PathBuf>,
// Detection parameters
threshold: f32,
nms_threshold: f32,
executor_type: ExecutorType,
// UI state
is_processing: bool,
progress: f32,
status_message: String,
// Results
detection_result: Option<DetectionResult>,
comparison_result: Option<ComparisonResult>,
// Image data for display
current_image_handle: Option<image::Handle>,
processed_image_handle: Option<image::Handle>,
second_image_handle: Option<image::Handle>,
}
impl Default for FaceDetectorApp {
fn default() -> Self {
Self {
current_tab: Tab::Detection,
input_image: None,
second_image: None,
output_path: None,
threshold: 0.8,
nms_threshold: 0.3,
executor_type: ExecutorType::MnnCpu,
is_processing: false,
progress: 0.0,
status_message: "Ready".to_string(),
detection_result: None,
comparison_result: None,
current_image_handle: None,
processed_image_handle: None,
second_image_handle: None,
}
}
}
impl FaceDetectorApp {
fn new() -> (Self, Task<Message>) {
(Self::default(), Task::none())
}
fn title(&self) -> String {
"Face Detector - Rust GUI".to_string()
}
fn update(&mut self, message: Message) -> Task<Message> {
match message {
Message::TabChanged(tab) => {
self.current_tab = tab;
Task::none()
}
Message::OpenImageDialog => {
self.status_message = "Opening file dialog...".to_string();
Task::perform(
async {
FileDialog::new()
.add_filter("Images", &["jpg", "jpeg", "png", "bmp", "tiff", "webp"])
.pick_file()
},
Message::ImageSelected,
)
}
Message::ImageSelected(path) => {
if let Some(path) = path {
self.input_image = Some(path.clone());
self.status_message = format!("Selected: {}", path.display());
// Load image data for display
Task::perform(
async move {
match std::fs::read(&path) {
Ok(data) => Some(Arc::new(data)),
Err(_) => None,
}
},
Message::ImageLoaded,
)
} else {
self.status_message = "No file selected".to_string();
Task::none()
}
}
Message::OpenSecondImageDialog => Task::perform(
async {
FileDialog::new()
.add_filter("Images", &["jpg", "jpeg", "png", "bmp", "tiff", "webp"])
.pick_file()
},
Message::SecondImageSelected,
),
Message::SecondImageSelected(path) => {
if let Some(path) = path {
self.second_image = Some(path.clone());
self.status_message = format!("Second image selected: {}", path.display());
// Load second image data for display
Task::perform(
async move {
match std::fs::read(&path) {
Ok(data) => Some(Arc::new(data)),
Err(_) => None,
}
},
Message::SecondImageLoaded,
)
} else {
self.status_message = "No second image selected".to_string();
Task::none()
}
}
Message::SaveOutputDialog => Task::perform(
async {
FileDialog::new()
.add_filter("Images", &["jpg", "jpeg", "png"])
.save_file()
},
Message::OutputPathSelected,
),
Message::OutputPathSelected(path) => {
if let Some(path) = path {
self.output_path = Some(path.clone());
self.status_message = format!("Output will be saved to: {}", path.display());
} else {
self.status_message = "No output path selected".to_string();
}
Task::none()
}
Message::ThresholdChanged(value) => {
self.threshold = value;
Task::none()
}
Message::NmsThresholdChanged(value) => {
self.nms_threshold = value;
Task::none()
}
Message::ExecutorChanged(executor_type) => {
self.executor_type = executor_type;
Task::none()
}
Message::DetectFaces => {
if let Some(input_path) = &self.input_image {
self.is_processing = true;
self.progress = 0.0;
self.status_message = "Detecting faces...".to_string();
let input_path = input_path.clone();
let output_path = self.output_path.clone();
let threshold = self.threshold;
let nms_threshold = self.nms_threshold;
let executor_type = self.executor_type.clone();
Task::perform(
async move {
FaceDetectionBridge::detect_faces(
input_path,
output_path,
threshold,
nms_threshold,
executor_type,
)
.await
},
Message::DetectionComplete,
)
} else {
self.status_message = "Please select an image first".to_string();
Task::none()
}
}
Message::CompareFaces => {
if let (Some(image1), Some(image2)) = (&self.input_image, &self.second_image) {
self.is_processing = true;
self.progress = 0.0;
self.status_message = "Comparing faces...".to_string();
let image1 = image1.clone();
let image2 = image2.clone();
let threshold = self.threshold;
let nms_threshold = self.nms_threshold;
let executor_type = self.executor_type.clone();
Task::perform(
async move {
FaceDetectionBridge::compare_faces(
image1,
image2,
threshold,
nms_threshold,
executor_type,
)
.await
},
Message::ComparisonComplete,
)
} else {
self.status_message = "Please select both images for comparison".to_string();
Task::none()
}
}
Message::ClearResults => {
self.detection_result = None;
self.comparison_result = None;
self.processed_image_handle = None;
self.status_message = "Results cleared".to_string();
Task::none()
}
Message::DetectionComplete(result) => {
self.is_processing = false;
self.progress = 100.0;
match &result {
DetectionResult::Success {
faces_count,
processing_time,
processed_image,
..
} => {
self.status_message = format!(
"Detection complete! Found {} faces in {:.2}s",
faces_count, processing_time
);
// Update processed image if available
if let Some(image_data) = processed_image {
self.processed_image_handle =
Some(image::Handle::from_bytes(image_data.clone()));
}
}
DetectionResult::Error(error) => {
self.status_message = format!("Detection failed: {}", error);
}
}
self.detection_result = Some(result);
Task::none()
}
Message::ComparisonComplete(result) => {
self.is_processing = false;
self.progress = 100.0;
match &result {
ComparisonResult::Success {
best_similarity,
processing_time,
..
} => {
let interpretation = if *best_similarity > 0.8 {
"Very likely the same person"
} else if *best_similarity > 0.6 {
"Possibly the same person"
} else if *best_similarity > 0.4 {
"Unlikely to be the same person"
} else {
"Very unlikely to be the same person"
};
self.status_message = format!(
"Comparison complete! Similarity: {:.3} - {} (Processing time: {:.2}s)",
best_similarity, interpretation, processing_time
);
}
ComparisonResult::Error(error) => {
self.status_message = format!("Comparison failed: {}", error);
}
}
self.comparison_result = Some(result);
Task::none()
}
Message::ProgressUpdate(progress) => {
self.progress = progress;
Task::none()
}
Message::ImageLoaded(data) => {
if let Some(image_data) = data {
self.current_image_handle =
Some(image::Handle::from_bytes(image_data.as_ref().clone()));
self.status_message = "Image loaded successfully".to_string();
} else {
self.status_message = "Failed to load image".to_string();
}
Task::none()
}
Message::SecondImageLoaded(data) => {
if let Some(image_data) = data {
self.second_image_handle =
Some(image::Handle::from_bytes(image_data.as_ref().clone()));
self.status_message = "Second image loaded successfully".to_string();
} else {
self.status_message = "Failed to load second image".to_string();
}
Task::none()
}
Message::ProcessedImageUpdated(data) => {
if let Some(image_data) = data {
self.processed_image_handle = Some(image::Handle::from_bytes(image_data));
}
Task::none()
}
}
}
fn view(&self) -> Element<'_, Message> {
let tabs = row![
button("Detection")
.on_press(Message::TabChanged(Tab::Detection))
.style(if self.current_tab == Tab::Detection {
button::primary
} else {
button::secondary
}),
button("Comparison")
.on_press(Message::TabChanged(Tab::Comparison))
.style(if self.current_tab == Tab::Comparison {
button::primary
} else {
button::secondary
}),
button("Settings")
.on_press(Message::TabChanged(Tab::Settings))
.style(if self.current_tab == Tab::Settings {
button::primary
} else {
button::secondary
}),
]
.spacing(10)
.padding(10);
let content = match self.current_tab {
Tab::Detection => self.detection_view(),
Tab::Comparison => self.comparison_view(),
Tab::Settings => self.settings_view(),
};
let status_bar = container(
row![
text(&self.status_message),
Space::with_width(Length::Fill),
if self.is_processing {
Element::from(progress_bar(0.0..=100.0, self.progress))
} else {
Space::with_width(Length::Shrink).into()
}
]
.align_y(Alignment::Center)
.spacing(10),
)
.padding(10)
.style(container::bordered_box);
column![tabs, content, status_bar].into()
}
}
impl FaceDetectorApp {
fn detection_view(&self) -> Element<'_, Message> {
let file_section = column![
text("Input Image").size(18),
row![
button("Select Image").on_press(Message::OpenImageDialog),
text(
self.input_image
.as_ref()
.map(|p| p
.file_name()
.unwrap_or_default()
.to_string_lossy()
.to_string())
.unwrap_or_else(|| "No image selected".to_string())
),
]
.spacing(10)
.align_y(Alignment::Center),
row![
button("Output Path").on_press(Message::SaveOutputDialog),
text(
self.output_path
.as_ref()
.map(|p| p
.file_name()
.unwrap_or_default()
.to_string_lossy()
.to_string())
.unwrap_or_else(|| "Auto-generate".to_string())
),
]
.spacing(10)
.align_y(Alignment::Center),
]
.spacing(10);
// Image display section
let image_section = if let Some(ref handle) = self.current_image_handle {
let original_image = column![
text("Original Image").size(16),
container(
image(handle.clone())
.width(400)
.height(300)
.content_fit(iced::ContentFit::ScaleDown)
)
.style(container::bordered_box)
.padding(5),
]
.spacing(5)
.align_x(Alignment::Center);
let processed_section = if let Some(ref processed_handle) = self.processed_image_handle
{
column![
text("Detected Faces").size(16),
container(
image(processed_handle.clone())
.width(400)
.height(300)
.content_fit(iced::ContentFit::ScaleDown)
)
.style(container::bordered_box)
.padding(5),
]
.spacing(5)
.align_x(Alignment::Center)
} else {
column![
text("Detected Faces").size(16),
container(
text("Process image to see results").style(|_theme| text::Style {
color: Some(iced::Color::from_rgb(0.6, 0.6, 0.6)),
})
)
.width(400)
.height(300)
.style(container::bordered_box)
.padding(5)
.center_x(Length::Fill)
.center_y(Length::Fill),
]
.spacing(5)
.align_x(Alignment::Center)
};
row![original_image, processed_section]
.spacing(20)
.align_y(Alignment::Start)
} else {
row![
container(
text("Select an image to display").style(|_theme| text::Style {
color: Some(iced::Color::from_rgb(0.6, 0.6, 0.6)),
})
)
.width(400)
.height(300)
.style(container::bordered_box)
.padding(5)
.center_x(Length::Fill)
.center_y(Length::Fill)
]
};
let controls = column![
text("Detection Parameters").size(18),
row![
text("Threshold:"),
slider(0.1..=1.0, self.threshold, Message::ThresholdChanged).step(0.01),
text(format!("{:.2}", self.threshold)),
]
.spacing(10)
.align_y(Alignment::Center),
row![
text("NMS Threshold:"),
slider(0.1..=1.0, self.nms_threshold, Message::NmsThresholdChanged).step(0.01),
text(format!("{:.2}", self.nms_threshold)),
]
.spacing(10)
.align_y(Alignment::Center),
row![
button("Detect Faces")
.on_press(Message::DetectFaces)
.style(button::primary),
button("Clear Results").on_press(Message::ClearResults),
]
.spacing(10),
]
.spacing(10);
let results = if let Some(result) = &self.detection_result {
match result {
DetectionResult::Success {
faces_count,
processing_time,
..
} => column![
text("Detection Results").size(18),
text(format!("Faces detected: {}", faces_count)),
text(format!("Processing time: {:.2}s", processing_time)),
]
.spacing(5),
DetectionResult::Error(error) => column![
text("Detection Results").size(18),
text(format!("Error: {}", error)).style(text::danger),
]
.spacing(5),
}
} else {
column![text("No results yet").style(|_theme| text::Style {
color: Some(iced::Color::from_rgb(0.6, 0.6, 0.6)),
})]
};
column![file_section, image_section, controls, results]
.spacing(20)
.padding(20)
.into()
}
fn comparison_view(&self) -> Element<'_, Message> {
let file_section = column![
text("Image Comparison").size(18),
row![
button("Select First Image").on_press(Message::OpenImageDialog),
text(
self.input_image
.as_ref()
.map(|p| p
.file_name()
.unwrap_or_default()
.to_string_lossy()
.to_string())
.unwrap_or_else(|| "No image selected".to_string())
),
]
.spacing(10)
.align_y(Alignment::Center),
row![
button("Select Second Image").on_press(Message::OpenSecondImageDialog),
text(
self.second_image
.as_ref()
.map(|p| p
.file_name()
.unwrap_or_default()
.to_string_lossy()
.to_string())
.unwrap_or_else(|| "No image selected".to_string())
),
]
.spacing(10)
.align_y(Alignment::Center),
]
.spacing(10);
// Image comparison display section
let comparison_image_section = {
let first_image = if let Some(ref handle) = self.current_image_handle {
column![
text("First Image").size(16),
container(
image(handle.clone())
.width(350)
.height(250)
.content_fit(iced::ContentFit::ScaleDown)
)
.style(container::bordered_box)
.padding(5),
]
.spacing(5)
.align_x(Alignment::Center)
} else {
column![
text("First Image").size(16),
container(text("Select first image").style(|_theme| text::Style {
color: Some(iced::Color::from_rgb(0.6, 0.6, 0.6)),
}))
.width(350)
.height(250)
.style(container::bordered_box)
.padding(5)
.center_x(Length::Fill)
.center_y(Length::Fill),
]
.spacing(5)
.align_x(Alignment::Center)
};
let second_image = if let Some(ref handle) = self.second_image_handle {
column![
text("Second Image").size(16),
container(
image(handle.clone())
.width(350)
.height(250)
.content_fit(iced::ContentFit::ScaleDown)
)
.style(container::bordered_box)
.padding(5),
]
.spacing(5)
.align_x(Alignment::Center)
} else {
column![
text("Second Image").size(16),
container(text("Select second image").style(|_theme| text::Style {
color: Some(iced::Color::from_rgb(0.6, 0.6, 0.6)),
}))
.width(350)
.height(250)
.style(container::bordered_box)
.padding(5)
.center_x(Length::Fill)
.center_y(Length::Fill),
]
.spacing(5)
.align_x(Alignment::Center)
};
row![first_image, second_image]
.spacing(20)
.align_y(Alignment::Start)
};
let controls = column![
text("Comparison Parameters").size(18),
row![
text("Threshold:"),
slider(0.1..=1.0, self.threshold, Message::ThresholdChanged).step(0.01),
text(format!("{:.2}", self.threshold)),
]
.spacing(10)
.align_y(Alignment::Center),
row![
text("NMS Threshold:"),
slider(0.1..=1.0, self.nms_threshold, Message::NmsThresholdChanged).step(0.01),
text(format!("{:.2}", self.nms_threshold)),
]
.spacing(10)
.align_y(Alignment::Center),
button("Compare Faces")
.on_press(Message::CompareFaces)
.style(button::primary),
]
.spacing(10);
let results = if let Some(result) = &self.comparison_result {
match result {
ComparisonResult::Success {
image1_faces,
image2_faces,
best_similarity,
processing_time,
} => {
let interpretation = if *best_similarity > 0.8 {
(
"Very likely the same person",
iced::Color::from_rgb(0.2, 0.8, 0.2),
)
} else if *best_similarity > 0.6 {
(
"Possibly the same person",
iced::Color::from_rgb(0.8, 0.8, 0.2),
)
} else if *best_similarity > 0.4 {
(
"Unlikely to be the same person",
iced::Color::from_rgb(0.8, 0.6, 0.2),
)
} else {
(
"Very unlikely to be the same person",
iced::Color::from_rgb(0.8, 0.2, 0.2),
)
};
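// For example, best_similarity = 0.72 is above 0.6 but not 0.8, so it falls
// in the second band and renders "Possibly the same person" in yellow.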
column![
text("Comparison Results").size(18),
text(format!("First image faces: {}", image1_faces)),
text(format!("Second image faces: {}", image2_faces)),
text(format!("Best similarity: {:.3}", best_similarity)),
text(interpretation.0).style(move |_theme| text::Style {
color: Some(interpretation.1),
}),
text(format!("Processing time: {:.2}s", processing_time)),
]
.spacing(5)
}
ComparisonResult::Error(error) => column![
text("Comparison Results").size(18),
text(format!("Error: {}", error)).style(text::danger),
]
.spacing(5),
}
} else {
column![
text("No comparison results yet").style(|_theme| text::Style {
color: Some(iced::Color::from_rgb(0.6, 0.6, 0.6)),
})
]
};
column![file_section, comparison_image_section, controls, results]
.spacing(20)
.padding(20)
.into()
}
fn settings_view(&self) -> Element<'_, Message> {
let executor_options = vec![
ExecutorType::MnnCpu,
ExecutorType::MnnMetal,
ExecutorType::MnnCoreML,
ExecutorType::OnnxCpu,
];
container(
column![
text("Model Settings").size(18),
row![
text("Execution Backend:"),
pick_list(
executor_options,
Some(self.executor_type.clone()),
Message::ExecutorChanged,
),
]
.spacing(10)
.align_y(Alignment::Center),
text("Detection Thresholds").size(18),
row![
text("Detection Threshold:"),
slider(0.1..=1.0, self.threshold, Message::ThresholdChanged).step(0.01),
text(format!("{:.2}", self.threshold)),
]
.spacing(10)
.align_y(Alignment::Center),
row![
text("NMS Threshold:"),
slider(0.1..=1.0, self.nms_threshold, Message::NmsThresholdChanged).step(0.01),
text(format!("{:.2}", self.nms_threshold)),
]
.spacing(10)
.align_y(Alignment::Center),
text("About").size(18),
text("Face Detection and Embedding - Rust GUI"),
text("Built with iced.rs and your face detection engine"),
]
.spacing(15)
.padding(20),
)
.height(Length::Shrink)
.into()
}
}
impl std::fmt::Display for ExecutorType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ExecutorType::MnnCpu => write!(f, "MNN (CPU)"),
ExecutorType::MnnMetal => write!(f, "MNN (Metal)"),
ExecutorType::MnnCoreML => write!(f, "MNN (CoreML)"),
ExecutorType::OnnxCpu => write!(f, "ONNX (CPU)"),
}
}
}
pub fn run() -> iced::Result {
iced::application(
"Face Detector",
FaceDetectorApp::update,
FaceDetectorApp::view,
)
.run_with(FaceDetectorApp::new)
}
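// With the old CLI `main.rs` deleted later in this diff, `run()` becomes the
// crate's GUI entry point. A minimal binary wrapper might look like the
// sketch below (illustrative, not part of this commit; the crate name
// `detector` is taken from the deleted main.rs):
//
//     fn main() -> iced::Result {
//         detector::gui::run()
//     }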

src/gui/bridge.rs Normal file (+367 lines)

@@ -0,0 +1,367 @@
use std::path::PathBuf;
use crate::facedet::{FaceDetectionConfig, FaceDetector, retinaface};
use crate::faceembed::facenet;
use crate::gui::app::{ComparisonResult, DetectionResult, ExecutorType};
use ndarray_image::ImageToNdarray;
const RETINAFACE_MODEL_MNN: &[u8] = include_bytes!("../../models/retinaface.mnn");
const FACENET_MODEL_MNN: &[u8] = include_bytes!("../../models/facenet.mnn");
const RETINAFACE_MODEL_ONNX: &[u8] = include_bytes!("../../models/retinaface.onnx");
const FACENET_MODEL_ONNX: &[u8] = include_bytes!("../../models/facenet.onnx");
pub struct FaceDetectionBridge;
impl FaceDetectionBridge {
pub async fn detect_faces(
image_path: PathBuf,
output_path: Option<PathBuf>,
threshold: f32,
nms_threshold: f32,
executor_type: ExecutorType,
) -> DetectionResult {
let start_time = std::time::Instant::now();
match Self::run_detection_internal(
image_path.clone(),
output_path,
threshold,
nms_threshold,
executor_type,
)
.await
{
Ok((faces_count, processed_image)) => {
let processing_time = start_time.elapsed().as_secs_f64();
DetectionResult::Success {
image_path,
faces_count,
processed_image,
processing_time,
}
}
Err(error) => DetectionResult::Error(error.to_string()),
}
}
pub async fn compare_faces(
image1_path: PathBuf,
image2_path: PathBuf,
threshold: f32,
nms_threshold: f32,
executor_type: ExecutorType,
) -> ComparisonResult {
let start_time = std::time::Instant::now();
match Self::run_comparison_internal(
image1_path,
image2_path,
threshold,
nms_threshold,
executor_type,
)
.await
{
Ok((image1_faces, image2_faces, best_similarity)) => {
let processing_time = start_time.elapsed().as_secs_f64();
ComparisonResult::Success {
image1_faces,
image2_faces,
best_similarity,
processing_time,
}
}
Err(error) => ComparisonResult::Error(error.to_string()),
}
}
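// Hypothetical call site from the iced update loop (not part of this diff);
// `Message::DetectionCompleted` and `Message::ComparisonCompleted` are
// assumed wrapper variants carrying the result back into `update`:
//
//     Task::perform(
//         FaceDetectionBridge::detect_faces(path, None, threshold, nms, executor),
//         Message::DetectionCompleted,
//     )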
async fn run_detection_internal(
image_path: PathBuf,
output_path: Option<PathBuf>,
threshold: f32,
nms_threshold: f32,
executor_type: ExecutorType,
) -> Result<(usize, Option<Vec<u8>>), Box<dyn std::error::Error + Send + Sync>> {
// Load the image
let img = image::open(&image_path)?;
let img_rgb = img.to_rgb8();
// Convert to ndarray format
let image_array = img_rgb.as_ndarray()?;
// Create detection configuration
let config = FaceDetectionConfig::default()
.with_threshold(threshold)
.with_nms_threshold(nms_threshold)
.with_input_width(1024)
.with_input_height(1024);
// Create detector and detect faces
let faces = match executor_type {
ExecutorType::MnnCpu | ExecutorType::MnnMetal | ExecutorType::MnnCoreML => {
let forward_type = match executor_type {
ExecutorType::MnnCpu => mnn::ForwardType::CPU,
ExecutorType::MnnMetal => mnn::ForwardType::Metal,
ExecutorType::MnnCoreML => mnn::ForwardType::CoreML,
_ => unreachable!(),
};
let mut detector = retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
.with_forward_type(forward_type)
.build()
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
detector
.detect_faces(image_array.view(), &config)
.map_err(|e| format!("Detection failed: {}", e))?
}
ExecutorType::OnnxCpu => {
let mut detector = retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
.map_err(|e| format!("Failed to create ONNX detector: {}", e))?
.build()
.map_err(|e| format!("Failed to build ONNX detector: {}", e))?;
detector
.detect_faces(image_array.view(), &config)
.map_err(|e| format!("Detection failed: {}", e))?
}
};
let faces_count = faces.bbox.len();
// Always generate an annotated copy for the GUI preview; `output_path`
// only controls whether it is also written to disk.
let processed_image = {
let mut output_img = img.to_rgb8();
for bbox in &faces.bbox {
let min_point = bbox.min_vertex();
let size = bbox.size();
let rect = imageproc::rect::Rect::at(min_point.x as i32, min_point.y as i32)
.of_size(size.x as u32, size.y as u32);
imageproc::drawing::draw_hollow_rect_mut(
&mut output_img,
rect,
image::Rgb([255, 0, 0]),
);
}
// Convert to bytes for GUI display
let mut buffer = Vec::new();
let mut cursor = std::io::Cursor::new(&mut buffer);
image::DynamicImage::ImageRgb8(output_img.clone())
.write_to(&mut cursor, image::ImageFormat::Png)?;
// Save to file if output path is specified
if let Some(ref output_path) = output_path {
output_img.save(output_path)?;
}
Some(buffer)
};
Ok((faces_count, processed_image))
}
async fn run_comparison_internal(
image1_path: PathBuf,
image2_path: PathBuf,
threshold: f32,
nms_threshold: f32,
executor_type: ExecutorType,
) -> Result<(usize, usize, f32), Box<dyn std::error::Error + Send + Sync>> {
// Load both images
let img1 = image::open(&image1_path)?.to_rgb8();
let img2 = image::open(&image2_path)?.to_rgb8();
// Convert to ndarray format
let image1_array = img1.as_ndarray()?;
let image2_array = img2.as_ndarray()?;
// Detection configuration (identical for both images, so build it once)
let config = FaceDetectionConfig::default()
.with_threshold(threshold)
.with_nms_threshold(nms_threshold)
.with_input_width(1024)
.with_input_height(1024);
// Create detector and embedder, detect faces and generate embeddings
let (faces1, faces2, best_similarity) = match executor_type {
ExecutorType::MnnCpu | ExecutorType::MnnMetal | ExecutorType::MnnCoreML => {
let forward_type = match executor_type {
ExecutorType::MnnCpu => mnn::ForwardType::CPU,
ExecutorType::MnnMetal => mnn::ForwardType::Metal,
ExecutorType::MnnCoreML => mnn::ForwardType::CoreML,
_ => unreachable!(),
};
let mut detector = retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
.with_forward_type(forward_type.clone())
.build()
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
let embedder = facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
.map_err(|e| format!("Failed to create MNN embedder: {}", e))?
.with_forward_type(forward_type)
.build()
.map_err(|e| format!("Failed to build MNN embedder: {}", e))?;
// Detect faces in both images
let faces1 = detector
.detect_faces(image1_array.view(), &config)
.map_err(|e| format!("Detection failed for image 1: {}", e))?;
let faces2 = detector
.detect_faces(image2_array.view(), &config)
.map_err(|e| format!("Detection failed for image 2: {}", e))?;
// Extract face crops and generate embeddings
let mut best_similarity = 0.0f32;
for bbox1 in &faces1.bbox {
let crop1 = Self::crop_face_from_image(&img1, bbox1)?;
let crop1_array = ndarray::Array::from_shape_vec(
(1, crop1.height() as usize, crop1.width() as usize, 3),
crop1
.pixels()
.flat_map(|p| [p.0[0], p.0[1], p.0[2]])
.collect(),
)?;
let embedding1 = embedder
.run_models(crop1_array.view())
.map_err(|e| format!("Embedding generation failed: {}", e))?;
for bbox2 in &faces2.bbox {
let crop2 = Self::crop_face_from_image(&img2, bbox2)?;
let crop2_array = ndarray::Array::from_shape_vec(
(1, crop2.height() as usize, crop2.width() as usize, 3),
crop2
.pixels()
.flat_map(|p| [p.0[0], p.0[1], p.0[2]])
.collect(),
)?;
let embedding2 = embedder
.run_models(crop2_array.view())
.map_err(|e| format!("Embedding generation failed: {}", e))?;
let similarity = Self::cosine_similarity(
embedding1.row(0).as_slice().unwrap(),
embedding2.row(0).as_slice().unwrap(),
);
best_similarity = best_similarity.max(similarity);
}
}
(faces1, faces2, best_similarity)
}
ExecutorType::OnnxCpu => {
let mut detector = retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
.map_err(|e| format!("Failed to create ONNX detector: {}", e))?
.build()
.map_err(|e| format!("Failed to build ONNX detector: {}", e))?;
let mut embedder = facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
.map_err(|e| format!("Failed to create ONNX embedder: {}", e))?
.build()
.map_err(|e| format!("Failed to build ONNX embedder: {}", e))?;
// Detect faces in both images
let faces1 = detector
.detect_faces(image1_array.view(), &config)
.map_err(|e| format!("Detection failed for image 1: {}", e))?;
let faces2 = detector
.detect_faces(image2_array.view(), &config)
.map_err(|e| format!("Detection failed for image 2: {}", e))?;
// Extract face crops and generate embeddings
let mut best_similarity = 0.0f32;
for bbox1 in &faces1.bbox {
let crop1 = Self::crop_face_from_image(&img1, bbox1)?;
let crop1_array = ndarray::Array::from_shape_vec(
(1, crop1.height() as usize, crop1.width() as usize, 3),
crop1
.pixels()
.flat_map(|p| [p.0[0], p.0[1], p.0[2]])
.collect(),
)?;
let embedding1 = embedder
.run_models(crop1_array.view())
.map_err(|e| format!("Embedding generation failed: {}", e))?;
for bbox2 in &faces2.bbox {
let crop2 = Self::crop_face_from_image(&img2, bbox2)?;
let crop2_array = ndarray::Array::from_shape_vec(
(1, crop2.height() as usize, crop2.width() as usize, 3),
crop2
.pixels()
.flat_map(|p| [p.0[0], p.0[1], p.0[2]])
.collect(),
)?;
let embedding2 = embedder
.run_models(crop2_array.view())
.map_err(|e| format!("Embedding generation failed: {}", e))?;
let similarity = Self::cosine_similarity(
embedding1.row(0).as_slice().unwrap(),
embedding2.row(0).as_slice().unwrap(),
);
best_similarity = best_similarity.max(similarity);
}
}
(faces1, faces2, best_similarity)
}
};
Ok((faces1.bbox.len(), faces2.bbox.len(), best_similarity))
}
fn crop_face_from_image(
img: &image::RgbImage,
bbox: &bounding_box::Aabb2<usize>,
) -> Result<image::RgbImage, Box<dyn std::error::Error + Send + Sync>> {
let min_point = bbox.min_vertex();
let size = bbox.size();
let x = min_point.x as u32;
let y = min_point.y as u32;
let width = size.x as u32;
let height = size.y as u32;
// Ensure crop bounds are within image
let img_width = img.width();
let img_height = img.height();
let crop_x = x.min(img_width.saturating_sub(1));
let crop_y = y.min(img_height.saturating_sub(1));
let crop_width = width.min(img_width - crop_x);
let crop_height = height.min(img_height - crop_y);
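// Example: for a 640x480 image and a box at (600, 400) sized 100x100,
// this clamps the crop to 40x80 at (600, 400) instead of reading past
// the image edge.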
Ok(image::imageops::crop_imm(img, crop_x, crop_y, crop_width, crop_height).to_image())
}
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm_a == 0.0 || norm_b == 0.0 {
0.0
} else {
dot_product / (norm_a * norm_b)
}
}
}
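// A quick sanity check for `cosine_similarity` (illustrative, not part of
// this commit): identical directions score 1, orthogonal vectors score 0,
// and the zero-norm guard returns 0 instead of NaN.
//
//     #[cfg(test)]
//     mod tests {
//         use super::*;
//
//         #[test]
//         fn cosine_similarity_basics() {
//             let sim = FaceDetectionBridge::cosine_similarity;
//             assert!((sim(&[1.0, 0.0], &[2.0, 0.0]) - 1.0).abs() < 1e-6);
//             assert!(sim(&[1.0, 0.0], &[0.0, 1.0]).abs() < 1e-6);
//             assert_eq!(sim(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
//         }
//     }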

src/gui/mod.rs Normal file (+5 lines)

@@ -0,0 +1,5 @@
pub mod app;
pub mod bridge;
pub use app::{FaceDetectorApp, Message, run};
pub use bridge::FaceDetectionBridge;

src/image.rs Deleted file (-5 lines)

@@ -1,5 +0,0 @@
// pub struct Image {
// pub width: u32,
// pub height: u32,
// pub data: Vec<u8>,
// }

src/lib.rs

@@ -1,4 +1,7 @@
+pub mod database;
 pub mod errors;
 pub mod facedet;
-pub mod image;
-use errors::*;
+pub mod faceembed;
+pub mod gui;
+pub mod ort_ep;
+pub use errors::*;

src/main.rs Deleted file (-59 lines)

@@ -1,59 +0,0 @@
mod cli;
mod errors;
use detector::facedet::retinaface::FaceDetectionConfig;
use errors::*;
use ndarray_image::*;
const RETINAFACE_MODEL: &[u8] = include_bytes!("../models/retinaface.mnn");
pub fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_env_filter("trace")
.with_thread_ids(true)
.with_thread_names(true)
.with_target(false)
.init();
let args = <cli::Cli as clap::Parser>::parse();
match args.cmd {
cli::SubCommand::Detect(detect) => {
use detector::facedet;
let model = facedet::retinaface::FaceDetection::new_from_bytes(RETINAFACE_MODEL)
.change_context(errors::Error)
.attach_printable("Failed to create face detection model")?;
let image = image::open(detect.image).change_context(Error)?;
let image = image.into_rgb8();
let mut array = image
.into_ndarray()
.change_context(errors::Error)
.attach_printable("Failed to convert image to ndarray")?;
let output = model
.detect_faces(
array.clone(),
FaceDetectionConfig::default().with_threshold(detect.threshold),
)
.change_context(errors::Error)
.attach_printable("Failed to detect faces")?;
for bbox in output.bbox {
tracing::info!("Detected face: {:?}", bbox);
use bounding_box::draw::*;
array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1);
}
let v = array.view();
if let Some(output) = detect.output {
let image: image::RgbImage = v
.to_image()
.change_context(errors::Error)
.attach_printable("Failed to convert ndarray to image")?;
image
.save(output)
.change_context(errors::Error)
.attach_printable("Failed to save output image")?;
}
}
cli::SubCommand::List(list) => {
println!("List: {:?}", list);
}
cli::SubCommand::Completions { shell } => {
cli::Cli::completions(shell);
}
}
Ok(())
}

src/ort_ep.rs Normal file (+197 lines)

@@ -0,0 +1,197 @@
#[cfg(feature = "ort-cuda")]
use ort::execution_providers::CUDAExecutionProvider;
#[cfg(feature = "ort-coreml")]
use ort::execution_providers::CoreMLExecutionProvider;
#[cfg(feature = "ort-directml")]
use ort::execution_providers::DirectMLExecutionProvider;
#[cfg(feature = "ort-openvino")]
use ort::execution_providers::OpenVINOExecutionProvider;
#[cfg(feature = "ort-tvm")]
use ort::execution_providers::TVMExecutionProvider;
#[cfg(feature = "ort-tensorrt")]
use ort::execution_providers::TensorRTExecutionProvider;
use ort::execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch};
/// Supported execution providers for ONNX Runtime
#[derive(Debug, Clone)]
pub enum ExecutionProvider {
/// CPU execution provider (always available)
CPU,
/// CoreML execution provider (macOS only)
CoreML,
/// CUDA execution provider (requires cuda feature)
CUDA,
/// TensorRT execution provider (requires tensorrt feature)
TensorRT,
/// TVM execution provider (requires tvm feature)
TVM,
/// OpenVINO execution provider (requires openvino feature)
OpenVINO,
/// DirectML execution provider (Windows only, requires directml feature)
DirectML,
}
impl std::fmt::Display for ExecutionProvider {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ExecutionProvider::CPU => write!(f, "CPU"),
ExecutionProvider::CoreML => write!(f, "CoreML"),
ExecutionProvider::CUDA => write!(f, "CUDA"),
ExecutionProvider::TensorRT => write!(f, "TensorRT"),
ExecutionProvider::TVM => write!(f, "TVM"),
ExecutionProvider::OpenVINO => write!(f, "OpenVINO"),
ExecutionProvider::DirectML => write!(f, "DirectML"),
}
}
}
impl std::str::FromStr for ExecutionProvider {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"cpu" => Ok(ExecutionProvider::CPU),
"coreml" => Ok(ExecutionProvider::CoreML),
"cuda" => Ok(ExecutionProvider::CUDA),
"tensorrt" => Ok(ExecutionProvider::TensorRT),
"tvm" => Ok(ExecutionProvider::TVM),
"openvino" => Ok(ExecutionProvider::OpenVINO),
"directml" => Ok(ExecutionProvider::DirectML),
_ => Err(format!("Unknown execution provider: {}", s)),
}
}
}
impl ExecutionProvider {
/// Returns all available execution providers for the current platform and features
pub fn available_providers() -> Vec<ExecutionProvider> {
vec![
ExecutionProvider::CPU,
#[cfg(all(target_os = "macos", feature = "ort-coreml"))]
ExecutionProvider::CoreML,
#[cfg(feature = "ort-cuda")]
ExecutionProvider::CUDA,
#[cfg(feature = "ort-tensorrt")]
ExecutionProvider::TensorRT,
#[cfg(feature = "ort-tvm")]
ExecutionProvider::TVM,
#[cfg(feature = "ort-openvino")]
ExecutionProvider::OpenVINO,
#[cfg(all(target_os = "windows", feature = "ort-directml"))]
ExecutionProvider::DirectML,
]
}
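// On, for example, a macOS build with only `ort-coreml` enabled this yields
// [CPU, CoreML]; the cfg attributes strip the remaining variants at compile
// time.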
/// Check if this execution provider is available on the current platform
pub fn is_available(&self) -> bool {
match self {
ExecutionProvider::CPU => true,
ExecutionProvider::CoreML => cfg!(target_os = "macos") && cfg!(feature = "ort-coreml"),
ExecutionProvider::CUDA => cfg!(feature = "ort-cuda"),
ExecutionProvider::TensorRT => cfg!(feature = "ort-tensorrt"),
ExecutionProvider::TVM => cfg!(feature = "ort-tvm"),
ExecutionProvider::OpenVINO => cfg!(feature = "ort-openvino"),
ExecutionProvider::DirectML => {
cfg!(target_os = "windows") && cfg!(feature = "ort-directml")
}
}
}
}
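// Hypothetical selection helper (not in this diff): honor a user-requested
// provider only when it is compiled in and supported on this platform,
// falling back to the always-available CPU provider otherwise.
//
//     fn select_provider(requested: &str) -> ExecutionProvider {
//         requested
//             .parse::<ExecutionProvider>()
//             .ok()
//             .filter(|ep| ep.is_available())
//             .unwrap_or(ExecutionProvider::CPU)
//     }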
impl ExecutionProvider {
pub fn to_dispatch(&self) -> Option<ExecutionProviderDispatch> {
match self {
ExecutionProvider::CPU => Some(CPUExecutionProvider::default().build()),
ExecutionProvider::CoreML => {
#[cfg(target_os = "macos")]
{
#[cfg(feature = "ort-coreml")]
{
Some(
CoreMLExecutionProvider::default()
.with_model_format(
ort::execution_providers::coreml::CoreMLModelFormat::MLProgram,
)
.build(),
)
}
#[cfg(not(feature = "ort-coreml"))]
{
tracing::error!("coreml support not compiled in");
None
}
}
#[cfg(not(target_os = "macos"))]
{
tracing::error!("CoreML is only available on macOS");
None
}
}
ExecutionProvider::CUDA => {
#[cfg(feature = "ort-cuda")]
{
Some(CUDAExecutionProvider::default().build())
}
#[cfg(not(feature = "ort-cuda"))]
{
tracing::error!("CUDA support not compiled in");
None
}
}
ExecutionProvider::TensorRT => {
#[cfg(feature = "ort-tensorrt")]
{
Some(TensorRTExecutionProvider::default().build())
}
#[cfg(not(feature = "ort-tensorrt"))]
{
tracing::error!("TensorRT support not compiled in");
None
}
}
ExecutionProvider::TVM => {
#[cfg(feature = "ort-tvm")]
{
Some(TVMExecutionProvider::default().build())
}
#[cfg(not(feature = "ort-tvm"))]
{
tracing::error!("TVM support not compiled in");
None
}
}
ExecutionProvider::OpenVINO => {
#[cfg(feature = "ort-openvino")]
{
Some(OpenVINOExecutionProvider::default().build())
}
#[cfg(not(feature = "ort-openvino"))]
{
tracing::error!("OpenVINO support not compiled in");
None
}
}
ExecutionProvider::DirectML => {
#[cfg(target_os = "windows")]
{
#[cfg(feature = "ort-directml")]
{
Some(DirectMLExecutionProvider::default().build())
}
#[cfg(not(feature = "ort-directml"))]
{
tracing::error!("DirectML support not compiled in");
None
}
}
#[cfg(not(target_os = "windows"))]
{
tracing::error!("DirectML is only available on Windows");
None
}
}
}
}
}
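// Sketch of wiring `to_dispatch` into a session, assuming the ort 2.x
// builder API (`Session::builder`, `with_execution_providers`,
// `commit_from_memory`); the CPU fallback and error handling are
// illustrative, not part of this commit:
//
//     fn build_session(ep: &ExecutionProvider, model: &[u8]) -> ort::Result<ort::session::Session> {
//         let dispatch = ep
//             .to_dispatch()
//             .unwrap_or_else(|| CPUExecutionProvider::default().build());
//         ort::session::Session::builder()?
//             .with_execution_providers([dispatch])?
//             .commit_from_memory(model)
//     }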