Compare commits: 043a845fc1...master (47 commits)

| SHA1 |
|---|
| 59a3fddc0b |
| eb9451aad8 |
| c6b3f5279f |
| a419a5ac4a |
| a340552257 |
| aaf34ef74e |
| ac8f1d01b4 |
| 4256c0af74 |
| 3eec262076 |
| c758fd8d41 |
| 34eaf9348a |
| dab7719206 |
| 4b4d23d1d4 |
| aab3d84db0 |
| 65560825fa |
| 0a5dbaaadc |
| 3e14a16739 |
| bfa389b497 |
| f8122892e0 |
| 97f64e7e10 |
| 37adb74adf |
| 47218fa696 |
| 61466c9edd |
| 33798467ba |
| 3d56db687c |
| cd12e97de3 |
| bd6520ce5a |
| cd9c65ff6b |
| cc26391610 |
| 783320131a |
| 7fc958b299 |
| 3aa95a2ef5 |
| e7c9c38ed7 |
| 5a1f4b9ef6 |
| 087d841959 |
| 050e937408 |
| 33afbfc2b8 |
| 2d2309837f |
| f5740dc87f |
| 3753e399b1 |
| d52b69911f |
| a3ea01b7b6 |
| e60921b099 |
| e91ae5b865 |
| 2c43f657aa |
| 8d07b0846c |
| f7aae32caf |

.gitattributes (vendored, new file, 4 lines)
@@ -0,0 +1,4 @@
models/facenet.onnx filter=lfs diff=lfs merge=lfs -text
models/retinaface.mnn filter=lfs diff=lfs merge=lfs -text
models/retinaface.onnx filter=lfs diff=lfs merge=lfs -text
models/facenet.mnn filter=lfs diff=lfs merge=lfs -text

.gitignore (vendored, 3 lines changed)
@@ -2,3 +2,6 @@
/target
.direnv
*.jpg
face_net.onnx
.DS_Store
*.cache

.gitmodules (vendored, new file, 3 lines)
@@ -0,0 +1,3 @@
[submodule "rfcs"]
	path = rfcs
	url = git@github.com:aftershootco/rfcs.git

Cargo.lock (generated, 4975 lines)
File diff suppressed because it is too large.

Cargo.toml (67 lines changed)
@@ -1,14 +1,27 @@
[workspace]
members = ["ndarray-image", "ndarray-resize", ".", "bounding-box"]
members = [
    "ndarray-image",
    "ndarray-resize",
    ".",
    "bounding-box",
    "ndarray-safetensors",
    "sqlite3-ndarray-math",
    "ndcv-bridge",
    "bbox",
]

[workspace.package]
version = "0.1.0"
edition = "2024"

[patch."https://github.com/uttarayan21/mnn-rs"]
mnn = { path = "/Users/fs0c131y/Projects/aftershoot/mnn-rs" }
[patch.crates-io]
linfa = { git = "https://github.com/relf/linfa", branch = "upgrade-ndarray-0.16" }
linfa-clustering = { git = "https://github.com/relf/linfa", branch = "upgrade-ndarray-0.16" }

[workspace.dependencies]
divan = { version = "0.1.21" }
ndarray-npy = "0.9.1"
serde = { version = "1.0", features = ["derive"] }
ndarray-image = { path = "ndarray-image" }
ndarray-resize = { path = "ndarray-resize" }
mnn = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.2.0", features = [
@@ -22,6 +35,16 @@ mnn-bridge = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0",
mnn-sync = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0", features = [
    "tracing",
], branch = "restructure-tensor-type" }
nalgebra = { version = "0.34.0", default-features = false, features = ["std"] }
opencv = { version = "0.95.1" }
bounding-box = { path = "bounding-box" }
bytemuck = "1.23.2"
error-stack = "0.5.0"
thiserror = "2.0"
fast_image_resize = "5.2.0"
img-parts = "0.4.0"
ndarray = { version = "0.16.1", features = ["rayon"] }
num = "0.4"

[package]
name = "detector"
@@ -35,12 +58,11 @@ clap_complete = "4.5"
error-stack = "0.5"
fast_image_resize = "5.2.0"
image = "0.25.6"
linfa = "0.7.1"
nalgebra = "0.33.2"
nalgebra = { workspace = true }
ndarray = "0.16.1"
ndarray-image = { workspace = true }
ndarray-resize = { workspace = true }
rusqlite = { version = "0.37.0", features = ["modern-full"] }
rusqlite = { version = "0.37.0", features = ["functions", "modern-full"] }
tap = "1.0.1"
thiserror = "2.0"
tokio = "1.43.1"
@@ -53,6 +75,39 @@ bounding-box = { version = "0.1.0", path = "bounding-box" }
color = "0.3.1"
itertools = "0.14.0"
ordered-float = "5.0.0"
ort = { version = "2.0.0-rc.10", default-features = false, features = [
    "std",
    "tracing",
    "ndarray",
    "cuda",
] }
ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" }
ndarray-safetensors = { version = "0.1.0", path = "ndarray-safetensors" }
sqlite3-ndarray-math = { version = "0.1.0", path = "sqlite3-ndarray-math" }

# GUI dependencies
iced = { version = "0.13", features = ["tokio", "image"] }
rfd = "0.15"
futures = "0.3"
imageproc = "0.25"
linfa = "0.7.1"
linfa-clustering = "0.7.1"

[profile.release]
debug = true

[features]
ort-cuda = []
ort-coreml = ["ort/coreml"]
ort-tensorrt = ["ort/tensorrt"]
ort-tvm = ["ort/tvm"]
ort-openvino = ["ort/openvino"]
ort-directml = ["ort/directml"]
mnn-metal = ["mnn/metal"]
mnn-coreml = ["mnn/coreml"]

default = ["ort-cuda"]

[[test]]
name = "test_bbox_replacement"
path = "test_bbox_replacement.rs"

Makefile.toml (new file, 38 lines)
@@ -0,0 +1,38 @@
[tasks.convert]
dependencies = ["convert_facenet", "convert_retinaface"]
workspace = false

[tasks.convert_facenet]
command = "MNNConvert"
args = [
    "-f",
    "ONNX",
    "--modelFile",
    "models/facenet.onnx",
    "--MNNModel",
    "models/facenet.mnn",
    "--fp16",
    "--bizCode",
    "MNN",
]
workspace = false

[tasks.convert_retinaface]
command = "MNNConvert"
args = [
    "-f",
    "ONNX",
    "--modelFile",
    "models/retinaface.onnx",
    "--MNNModel",
    "models/retinaface.mnn",
    "--fp16",
    "--bizCode",
    "MNN",
]
workspace = false

[tasks.gui]
command = "cargo"
args = ["run", "--release", "--bin", "gui"]
workspace = false

README.md (228 lines changed)
@@ -1,3 +1,227 @@
# Face Detection
# Face Detection and Embedding

Rust programs to do face detection and face embedding
A high-performance Rust implementation for face detection and face embedding generation using neural networks.

## Overview

This project provides a complete face detection and recognition pipeline with the following capabilities:

- **Face Detection**: Detect faces in images using RetinaFace model
- **Face Embedding**: Generate face embeddings using FaceNet model
- **Multiple Backends**: Support for both MNN and ONNX runtime execution
- **Hardware Acceleration**: Metal, CoreML, and OpenCL support on compatible platforms
- **Modular Design**: Workspace architecture with reusable components

## Features

- 🔍 **Accurate Face Detection** - Uses RetinaFace model for robust face detection
- 🧠 **Face Embeddings** - Generate 512-dimensional face embeddings with FaceNet
- ⚡ **High Performance** - Optimized with hardware acceleration (Metal, CoreML)
- 🔧 **Flexible Configuration** - Adjustable detection thresholds and NMS parameters
- 📦 **Modular Architecture** - Reusable components for image processing and bounding boxes
- 🖼️ **Visual Output** - Draw bounding boxes on detected faces

## Architecture

The project is organized as a Rust workspace with the following components:

- **`detector`** - Main face detection and embedding application
- **`bounding-box`** - Geometric operations and drawing utilities for bounding boxes
- **`ndarray-image`** - Conversion utilities between ndarray and image formats
- **`ndarray-resize`** - Fast image resizing operations on ndarray data

## Models

The project includes pre-trained neural network models:

- **RetinaFace** - Face detection model (`.mnn` and `.onnx` formats)
- **FaceNet** - Face embedding model (`.mnn` and `.onnx` formats)

## Usage

### Basic Face Detection

```bash
# Detect faces using MNN backend (default)
cargo run --release detect path/to/image.jpg

# Detect faces using ONNX Runtime backend
cargo run --release detect --executor onnx path/to/image.jpg

# Save output with bounding boxes drawn
cargo run --release detect --output detected.jpg path/to/image.jpg

# Adjust detection sensitivity
cargo run --release detect --threshold 0.9 --nms-threshold 0.4 path/to/image.jpg
```

### Face Comparison

Compare faces between two images by computing and comparing their embeddings:

```bash
# Compare faces in two images
cargo run --release compare image1.jpg image2.jpg

# Compare with custom thresholds
cargo run --release compare --threshold 0.9 --nms-threshold 0.4 image1.jpg image2.jpg

# Use ONNX Runtime backend for comparison
cargo run --release compare -p cpu image1.jpg image2.jpg

# Use MNN with Metal acceleration
cargo run --release compare -f metal image1.jpg image2.jpg
```

The compare command will:
1. Detect all faces in both images
2. Generate embeddings for each detected face
3. Compute cosine similarity between all face pairs
4. Display similarity scores and the best match
5. Provide interpretation of the similarity scores:
   - **> 0.8**: Very likely the same person
   - **0.6-0.8**: Possibly the same person
   - **0.4-0.6**: Unlikely to be the same person
   - **< 0.4**: Very unlikely to be the same person

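For reference, the cosine-similarity step (point 3) is the dot product of the two embeddings divided by the product of their norms. The sketch below is a minimal, self-contained illustration; the function name and plain-slice signature are illustrative, not the crate's actual API:

```rust
/// Cosine similarity between two face embeddings (e.g. 512-dimensional FaceNet vectors).
/// Illustrative only; the repository computes this internally.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "embeddings must have the same dimension");
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0; // degenerate embedding: report no similarity
    }
    dot / (norm_a * norm_b)
}

fn main() {
    // Toy vectors; real embeddings come from the FaceNet model.
    let a = vec![0.1_f32; 512];
    let b = vec![0.1_f32; 512];
    println!("similarity = {:.3}", cosine_similarity(&a, &b)); // 1.000 for identical vectors
}
```

Identical vectors score 1.0 and orthogonal vectors score 0.0, which is why the interpretation thresholds above sit between 0.4 and 0.8.
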
### Backend Selection

The project supports two inference backends:

- **MNN Backend** (default): High-performance inference framework with Metal/CoreML support
- **ONNX Runtime Backend**: Cross-platform ML inference with broad hardware support

```bash
# Use MNN backend with Metal acceleration (macOS)
cargo run --release detect --executor mnn --forward-type metal path/to/image.jpg

# Use ONNX Runtime backend
cargo run --release detect --executor onnx path/to/image.jpg
```

### Command Line Options

```bash
# Face detection with custom parameters
cargo run --release detect [OPTIONS] <IMAGE>

Options:
  -m, --model <MODEL>                  Custom model path
  -M, --model-type <MODEL_TYPE>        Model type [default: retina-face]
  -o, --output <OUTPUT>                Output image path
  -e, --executor <EXECUTOR>            Inference backend [mnn, onnx]
  -f, --forward-type <FORWARD_TYPE>    MNN execution backend [default: cpu]
  -t, --threshold <THRESHOLD>          Detection threshold [default: 0.8]
  -n, --nms-threshold <NMS_THRESHOLD>  NMS threshold [default: 0.3]
```

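The two thresholds map directly onto post-processing: `--threshold` discards candidate boxes whose confidence is too low, and `--nms-threshold` is the IoU above which an overlapping, lower-scored box is suppressed. A rough sketch of that step using the `nms` helper from the workspace's bounding-box crate follows; the import paths and exact parameter order are assumptions, not a verified API:

```rust
// Sketch only: `bounding_box::nms::{nms, NmsError}` and `Aabb2` are assumed exports of the workspace crate.
use bounding_box::{nms::{nms, NmsError}, Aabb2};

/// Keep the highest-scoring, non-overlapping detections.
fn keep_best_boxes(
    boxes: &[Aabb2<f32>],
    scores: &[f32],
) -> Result<Vec<Aabb2<f32>>, NmsError> {
    // Same defaults as the CLI: --threshold 0.8, --nms-threshold 0.3.
    let keep = nms(boxes, scores, 0.8, 0.3)?;
    Ok(boxes
        .iter()
        .enumerate()
        .filter(|(i, _)| keep.contains(i))
        .map(|(_, b)| *b)
        .collect())
}
```

Raising `--threshold` trades recall for precision; raising `--nms-threshold` keeps more overlapping boxes.
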
### Quick Start

```bash
# Build the project
cargo build --release

# Run face detection on sample image
just run
# or
cargo run --release detect ./1000066593.jpg
```

## Hardware Acceleration

### MNN Backend

The MNN backend supports various execution backends:

- **CPU** - Default, works on all platforms
- **Metal** - macOS GPU acceleration
- **CoreML** - macOS/iOS neural engine acceleration
- **OpenCL** - Cross-platform GPU acceleration

```bash
# Use Metal acceleration on macOS
cargo run --release detect --executor mnn --forward-type metal path/to/image.jpg

# Use CoreML on macOS/iOS
cargo run --release detect --executor mnn --forward-type coreml path/to/image.jpg
```

### ONNX Runtime Backend

The ONNX Runtime backend automatically selects the best available execution provider based on your system configuration.

## Development

### Prerequisites

- Rust 2024 edition
- MNN runtime (automatically linked)
- ONNX runtime (for ONNX backend)

### Building

```bash
# Standard build
cargo build

# Release build with optimizations
cargo build --release

# Run tests
cargo test
```

### Project Structure

```
├── src/
│   ├── facedet/           # Face detection modules
│   │   ├── mnn/           # MNN backend implementations
│   │   ├── ort/           # ONNX Runtime backend implementations
│   │   └── postprocess.rs # Shared postprocessing logic
│   ├── faceembed/         # Face embedding modules
│   │   ├── mnn/           # MNN backend implementations
│   │   └── ort/           # ONNX Runtime backend implementations
│   ├── cli.rs             # Command line interface
│   └── main.rs            # Application entry point
├── models/                # Neural network models (.mnn and .onnx)
├── bounding-box/          # Bounding box utilities
├── ndarray-image/         # Image conversion utilities
└── ndarray-resize/        # Image resizing utilities
```

### Backend Architecture

The codebase is organized to support multiple inference backends:

- **Common interfaces**: `FaceDetector` and `FaceEmbedder` traits provide unified APIs
- **Shared postprocessing**: Common logic for anchor generation, NMS, and coordinate decoding
- **Backend-specific implementations**: Separate modules for MNN and ONNX Runtime
- **Modular design**: Easy to add new backends by implementing the common traits

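To make the trait layer concrete, here is a rough sketch of what the two common interfaces could look like. The method names and signatures are illustrative assumptions; the real definitions live under `src/facedet` and `src/faceembed`:

```rust
// Hypothetical shapes for the shared traits, not the exact definitions in this repository.
use bounding_box::Aabb2;
use ndarray::Array3;

/// One detected face: where it is and how confident the detector is.
pub struct Detection {
    pub bbox: Aabb2<f32>,
    pub score: f32,
}

pub trait FaceDetector {
    type Error;
    /// Run detection on an HWC `u8` image and return every face above the score threshold.
    fn detect(&mut self, image: &Array3<u8>) -> Result<Vec<Detection>, Self::Error>;
}

pub trait FaceEmbedder {
    type Error;
    /// Produce a fixed-size embedding (512 floats for FaceNet) for a single face crop.
    fn embed(&mut self, face: &Array3<u8>) -> Result<Vec<f32>, Self::Error>;
}
```

A new backend then only needs to implement these two traits; the CLI and postprocessing code stay unchanged.
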
## License

MIT License

## Dependencies

Key dependencies include:

- **MNN** - High-performance neural network inference framework (MNN backend)
- **ONNX Runtime** - Cross-platform ML inference (ORT backend)
- **ndarray** - N-dimensional array processing
- **image** - Image processing and I/O
- **clap** - Command line argument parsing
- **bounding-box** - Geometric operations for face detection
- **error-stack** - Structured error handling

### Backend Status

- ✅ **MNN Backend**: Fully implemented with hardware acceleration support
- 🚧 **ONNX Runtime Backend**: Framework implemented, inference logic to be completed

*Note: The ORT backend currently provides the framework but requires completion of the inference implementation.*

---

*Built with Rust for maximum performance and safety in computer vision applications.*

assets/headshots (symbolic link, new, 1 line)
@@ -0,0 +1 @@
/Users/fs0c131y/Pictures/test_cases/compressed/HeadshotJpeg

bbox/Cargo.toml (new file, 13 lines)
@@ -0,0 +1,13 @@
[package]
name = "bbox"
version.workspace = true
edition.workspace = true

[dependencies]
ndarray.workspace = true
num = "0.4.3"
serde = { workspace = true, features = ["derive"], optional = true }

[features]
serde = ["dep:serde"]
default = ["serde"]
708
bbox/src/lib.rs
Normal file
708
bbox/src/lib.rs
Normal file
@@ -0,0 +1,708 @@
|
||||
pub mod traits;
|
||||
|
||||
/// A bounding box of co-ordinates whose origin is at the top-left corner.
|
||||
#[derive(
|
||||
Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Hash, serde::Serialize, serde::Deserialize,
|
||||
)]
|
||||
#[non_exhaustive]
|
||||
pub struct BBox<T = f32> {
|
||||
pub x: T,
|
||||
pub y: T,
|
||||
pub width: T,
|
||||
pub height: T,
|
||||
}
|
||||
|
||||
impl<T> From<[T; 4]> for BBox<T> {
|
||||
fn from([x, y, width, height]: [T; 4]) -> Self {
|
||||
Self {
|
||||
x,
|
||||
y,
|
||||
width,
|
||||
height,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Copy> BBox<T> {
|
||||
pub fn new(x: T, y: T, width: T, height: T) -> Self {
|
||||
Self {
|
||||
x,
|
||||
y,
|
||||
width,
|
||||
height,
|
||||
}
|
||||
}
|
||||
|
||||
/// Casts the internal values to another type using [as] keyword
|
||||
pub fn cast<T2>(self) -> BBox<T2>
|
||||
where
|
||||
T: num::cast::AsPrimitive<T2>,
|
||||
T2: Copy + 'static,
|
||||
{
|
||||
BBox {
|
||||
x: self.x.as_(),
|
||||
y: self.y.as_(),
|
||||
width: self.width.as_(),
|
||||
height: self.height.as_(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Clamps all the internal values to the given min and max.
|
||||
pub fn clamp(&self, min: T, max: T) -> Self
|
||||
where
|
||||
T: std::cmp::PartialOrd,
|
||||
{
|
||||
Self {
|
||||
x: num::clamp(self.x, min, max),
|
||||
y: num::clamp(self.y, min, max),
|
||||
width: num::clamp(self.width, min, max),
|
||||
height: num::clamp(self.height, min, max),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clamp_box(&self, bbox: BBox<T>) -> Self
|
||||
where
|
||||
T: std::cmp::PartialOrd,
|
||||
T: num::Zero,
|
||||
T: core::ops::Add<Output = T>,
|
||||
T: core::ops::Sub<Output = T>,
|
||||
{
|
||||
let x1 = num::clamp(self.x1(), bbox.x1(), bbox.x2());
|
||||
let y1 = num::clamp(self.y1(), bbox.y1(), bbox.y2());
|
||||
let x2 = num::clamp(self.x2(), bbox.x1(), bbox.x2());
|
||||
let y2 = num::clamp(self.y2(), bbox.y1(), bbox.y2());
|
||||
Self::new_xyxy(x1, y1, x2, y2)
|
||||
}
|
||||
|
||||
pub fn normalize(&self, width: T, height: T) -> Self
|
||||
where
|
||||
T: core::ops::Div<Output = T> + Copy,
|
||||
{
|
||||
Self {
|
||||
x: self.x / width,
|
||||
y: self.y / height,
|
||||
width: self.width / width,
|
||||
height: self.height / height,
|
||||
}
|
||||
}
|
||||
|
||||
/// Normalize after casting to float
|
||||
pub fn normalize_f64(&self, width: T, height: T) -> BBox<f64>
|
||||
where
|
||||
T: core::ops::Div<Output = T> + Copy,
|
||||
T: num::cast::AsPrimitive<f64>,
|
||||
{
|
||||
BBox {
|
||||
x: self.x.as_() / width.as_(),
|
||||
y: self.y.as_() / height.as_(),
|
||||
width: self.width.as_() / width.as_(),
|
||||
height: self.height.as_() / height.as_(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn denormalize(&self, width: T, height: T) -> Self
|
||||
where
|
||||
T: core::ops::Mul<Output = T> + Copy,
|
||||
{
|
||||
Self {
|
||||
x: self.x * width,
|
||||
y: self.y * height,
|
||||
width: self.width * width,
|
||||
height: self.height * height,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn height(&self) -> T {
|
||||
self.height
|
||||
}
|
||||
|
||||
pub fn width(&self) -> T {
|
||||
self.width
|
||||
}
|
||||
|
||||
pub fn padding(&self, padding: T) -> Self
|
||||
where
|
||||
T: core::ops::Add<Output = T> + core::ops::Sub<Output = T> + Copy,
|
||||
{
|
||||
Self {
|
||||
x: self.x - padding,
|
||||
y: self.y - padding,
|
||||
width: self.width + padding + padding,
|
||||
height: self.height + padding + padding,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn padding_height(&self, padding: T) -> Self
|
||||
where
|
||||
T: core::ops::Add<Output = T> + core::ops::Sub<Output = T> + Copy,
|
||||
{
|
||||
Self {
|
||||
x: self.x,
|
||||
y: self.y - padding,
|
||||
width: self.width,
|
||||
height: self.height + padding + padding,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn padding_width(&self, padding: T) -> Self
|
||||
where
|
||||
T: core::ops::Add<Output = T> + core::ops::Sub<Output = T> + Copy,
|
||||
{
|
||||
Self {
|
||||
x: self.x - padding,
|
||||
y: self.y,
|
||||
width: self.width + padding + padding,
|
||||
height: self.height,
|
||||
}
|
||||
}
|
||||
|
||||
// Enlarge / shrink the bounding box by a factor while
|
||||
// keeping the center point and the aspect ratio fixed
|
||||
pub fn scale(&self, factor: T) -> Self
|
||||
where
|
||||
T: core::ops::Mul<Output = T>,
|
||||
T: core::ops::Sub<Output = T>,
|
||||
T: core::ops::Add<Output = T>,
|
||||
T: core::ops::Div<Output = T>,
|
||||
T: num::One + Copy,
|
||||
{
|
||||
let two = num::one::<T>() + num::one::<T>();
|
||||
let width = self.width * factor;
|
||||
let height = self.height * factor;
|
||||
let width_inc = width - self.width;
|
||||
let height_inc = height - self.height;
|
||||
Self {
|
||||
x: self.x - width_inc / two,
|
||||
y: self.y - height_inc / two,
|
||||
width,
|
||||
height,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn scale_x(&self, factor: T) -> Self
|
||||
where
|
||||
T: core::ops::Mul<Output = T>
|
||||
+ core::ops::Sub<Output = T>
|
||||
+ core::ops::Add<Output = T>
|
||||
+ core::ops::Div<Output = T>
|
||||
+ num::One
|
||||
+ Copy,
|
||||
{
|
||||
let two = num::one::<T>() + num::one::<T>();
|
||||
let width = self.width * factor;
|
||||
let width_inc = width - self.width;
|
||||
Self {
|
||||
x: self.x - width_inc / two,
|
||||
y: self.y,
|
||||
width,
|
||||
height: self.height,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn scale_y(&self, factor: T) -> Self
|
||||
where
|
||||
T: core::ops::Mul<Output = T>
|
||||
+ core::ops::Sub<Output = T>
|
||||
+ core::ops::Add<Output = T>
|
||||
+ core::ops::Div<Output = T>
|
||||
+ num::One
|
||||
+ Copy,
|
||||
{
|
||||
let two = num::one::<T>() + num::one::<T>();
|
||||
let height = self.height * factor;
|
||||
let height_inc = height - self.height;
|
||||
Self {
|
||||
x: self.x,
|
||||
y: self.y - height_inc / two,
|
||||
width: self.width,
|
||||
height,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn offset(&self, offset: Point<T>) -> Self
|
||||
where
|
||||
T: core::ops::Add<Output = T> + Copy,
|
||||
{
|
||||
Self {
|
||||
x: self.x + offset.x,
|
||||
y: self.y + offset.y,
|
||||
width: self.width,
|
||||
height: self.height,
|
||||
}
|
||||
}
|
||||
|
||||
/// Translate the bounding box by the given offset
|
||||
/// if they are in the same scale
|
||||
pub fn translate(&self, bbox: Self) -> Self
|
||||
where
|
||||
T: core::ops::Add<Output = T> + Copy,
|
||||
{
|
||||
Self {
|
||||
x: self.x + bbox.x,
|
||||
y: self.y + bbox.y,
|
||||
width: self.width,
|
||||
height: self.height,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_top_left(&self, top_left: Point<T>) -> Self {
|
||||
Self {
|
||||
x: top_left.x,
|
||||
y: top_left.y,
|
||||
width: self.width,
|
||||
height: self.height,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn center(&self) -> Point<T>
|
||||
where
|
||||
T: core::ops::Add<Output = T> + core::ops::Div<Output = T> + Copy,
|
||||
T: num::One,
|
||||
{
|
||||
let two = T::one() + T::one();
|
||||
Point::new(self.x + self.width / two, self.y + self.height / two)
|
||||
}
|
||||
|
||||
pub fn area(&self) -> T
|
||||
where
|
||||
T: core::ops::Mul<Output = T> + Copy,
|
||||
{
|
||||
self.width * self.height
|
||||
}
|
||||
|
||||
// Corresponds to self.x1() and self.y1()
|
||||
pub fn top_left(&self) -> Point<T> {
|
||||
Point::new(self.x, self.y)
|
||||
}
|
||||
|
||||
pub fn top_right(&self) -> Point<T>
|
||||
where
|
||||
T: core::ops::Add<Output = T> + Copy,
|
||||
{
|
||||
Point::new(self.x + self.width, self.y)
|
||||
}
|
||||
|
||||
pub fn bottom_left(&self) -> Point<T>
|
||||
where
|
||||
T: core::ops::Add<Output = T> + Copy,
|
||||
{
|
||||
Point::new(self.x, self.y + self.height)
|
||||
}
|
||||
|
||||
// Corresponds to self.x2() and self.y2()
|
||||
pub fn bottom_right(&self) -> Point<T>
|
||||
where
|
||||
T: core::ops::Add<Output = T> + Copy,
|
||||
{
|
||||
Point::new(self.x + self.width, self.y + self.height)
|
||||
}
|
||||
|
||||
pub const fn x1(&self) -> T {
|
||||
self.x
|
||||
}
|
||||
|
||||
pub const fn y1(&self) -> T {
|
||||
self.y
|
||||
}
|
||||
|
||||
pub fn x2(&self) -> T
|
||||
where
|
||||
T: core::ops::Add<Output = T> + Copy,
|
||||
{
|
||||
self.x + self.width
|
||||
}
|
||||
|
||||
pub fn y2(&self) -> T
|
||||
where
|
||||
T: core::ops::Add<Output = T> + Copy,
|
||||
{
|
||||
self.y + self.height
|
||||
}
|
||||
|
||||
pub fn overlap(&self, other: &Self) -> T
|
||||
where
|
||||
T: std::cmp::PartialOrd
|
||||
+ traits::min::Min
|
||||
+ traits::max::Max
|
||||
+ num::Zero
|
||||
+ core::ops::Add<Output = T>
|
||||
+ core::ops::Sub<Output = T>
|
||||
+ core::ops::Mul<Output = T>
|
||||
+ Copy,
|
||||
{
|
||||
let x1 = self.x.max(other.x);
|
||||
let y1 = self.y.max(other.y);
|
||||
let x2 = (self.x + self.width).min(other.x + other.width);
|
||||
let y2 = (self.y + self.height).min(other.y + other.height);
|
||||
let width = (x2 - x1).max(T::zero());
|
||||
let height = (y2 - y1).max(T::zero());
|
||||
width * height
|
||||
}
|
||||
|
||||
pub fn iou(&self, other: &Self) -> T
|
||||
where
|
||||
T: std::cmp::Ord
|
||||
+ num::Zero
|
||||
+ traits::min::Min
|
||||
+ traits::max::Max
|
||||
+ core::ops::Add<Output = T>
|
||||
+ core::ops::Sub<Output = T>
|
||||
+ core::ops::Mul<Output = T>
|
||||
+ core::ops::Div<Output = T>
|
||||
+ Copy,
|
||||
{
|
||||
let overlap = self.overlap(other);
|
||||
let union = self.area() + other.area() - overlap;
|
||||
overlap / union
|
||||
}
|
||||
|
||||
pub fn contains(&self, point: Point<T>) -> bool
|
||||
where
|
||||
T: std::cmp::PartialOrd + core::ops::Add<Output = T> + Copy,
|
||||
{
|
||||
point.x >= self.x
|
||||
&& point.x <= self.x + self.width
|
||||
&& point.y >= self.y
|
||||
&& point.y <= self.y + self.height
|
||||
}
|
||||
|
||||
pub fn contains_bbox(&self, other: Self) -> bool
|
||||
where
|
||||
T: std::cmp::PartialOrd + Copy,
|
||||
T: core::ops::Add<Output = T>,
|
||||
{
|
||||
self.contains(other.top_left())
|
||||
&& self.contains(other.top_right())
|
||||
&& self.contains(other.bottom_left())
|
||||
&& self.contains(other.bottom_right())
|
||||
}
|
||||
|
||||
pub fn new_xywh(x: T, y: T, width: T, height: T) -> Self {
|
||||
Self {
|
||||
x,
|
||||
y,
|
||||
width,
|
||||
height,
|
||||
}
|
||||
}
|
||||
pub fn new_xyxy(x1: T, y1: T, x2: T, y2: T) -> Self
|
||||
where
|
||||
T: core::ops::Sub<Output = T> + Copy,
|
||||
{
|
||||
Self {
|
||||
x: x1,
|
||||
y: y1,
|
||||
width: x2 - x1,
|
||||
height: y2 - y1,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn containing(box1: Self, box2: Self) -> Self
|
||||
where
|
||||
T: traits::min::Min + traits::max::Max + Copy,
|
||||
T: core::ops::Sub<Output = T>,
|
||||
T: core::ops::Add<Output = T>,
|
||||
{
|
||||
let x1 = box1.x.min(box2.x);
|
||||
let y1 = box1.y.min(box2.y);
|
||||
let x2 = box1.x2().max(box2.x2());
|
||||
let y2 = box1.y2().max(box2.y2());
|
||||
Self::new_xyxy(x1, y1, x2, y2)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: core::ops::Sub<Output = T> + Copy> core::ops::Sub<T> for BBox<T> {
|
||||
type Output = BBox<T>;
|
||||
fn sub(self, rhs: T) -> Self::Output {
|
||||
BBox {
|
||||
x: self.x - rhs,
|
||||
y: self.y - rhs,
|
||||
width: self.width - rhs,
|
||||
height: self.height - rhs,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: core::ops::Add<Output = T> + Copy> core::ops::Add<T> for BBox<T> {
|
||||
type Output = BBox<T>;
|
||||
fn add(self, rhs: T) -> Self::Output {
|
||||
BBox {
|
||||
x: self.x + rhs,
|
||||
y: self.y + rhs,
|
||||
width: self.width + rhs,
|
||||
height: self.height + rhs,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<T: core::ops::Mul<Output = T> + Copy> core::ops::Mul<T> for BBox<T> {
|
||||
type Output = BBox<T>;
|
||||
fn mul(self, rhs: T) -> Self::Output {
|
||||
BBox {
|
||||
x: self.x * rhs,
|
||||
y: self.y * rhs,
|
||||
width: self.width * rhs,
|
||||
height: self.height * rhs,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<T: core::ops::Div<Output = T> + Copy> core::ops::Div<T> for BBox<T> {
|
||||
type Output = BBox<T>;
|
||||
fn div(self, rhs: T) -> Self::Output {
|
||||
BBox {
|
||||
x: self.x / rhs,
|
||||
y: self.y / rhs,
|
||||
width: self.width / rhs,
|
||||
height: self.height / rhs,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> core::ops::Add<BBox<T>> for BBox<T>
|
||||
where
|
||||
T: core::ops::Sub<Output = T>
|
||||
+ core::ops::Add<Output = T>
|
||||
+ traits::min::Min
|
||||
+ traits::max::Max
|
||||
+ Copy,
|
||||
{
|
||||
type Output = BBox<T>;
|
||||
fn add(self, rhs: BBox<T>) -> Self::Output {
|
||||
let x1 = self.x1().min(rhs.x1());
|
||||
let y1 = self.y1().min(rhs.y1());
|
||||
let x2 = self.x2().max(rhs.x2());
|
||||
let y2 = self.y2().max(rhs.y2());
|
||||
BBox::new_xyxy(x1, y1, x2, y2)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bbox_add() {
|
||||
let bbox1: BBox<usize> = BBox::new_xyxy(0, 0, 10, 10);
|
||||
let bbox2: BBox<usize> = BBox::new_xyxy(5, 5, 15, 15);
|
||||
let bbox3: BBox<usize> = bbox1 + bbox2;
|
||||
assert_eq!(bbox3, BBox::new_xyxy(0, 0, 15, 15).cast());
|
||||
}
|
||||
|
||||
#[derive(
|
||||
Debug, Copy, Clone, serde::Serialize, serde::Deserialize, PartialEq, PartialOrd, Eq, Ord, Hash,
|
||||
)]
|
||||
pub struct Point<T = f32> {
|
||||
x: T,
|
||||
y: T,
|
||||
}
|
||||
|
||||
impl<T> Point<T> {
|
||||
pub const fn new(x: T, y: T) -> Self {
|
||||
Self { x, y }
|
||||
}
|
||||
|
||||
pub const fn x(&self) -> T
|
||||
where
|
||||
T: Copy,
|
||||
{
|
||||
self.x
|
||||
}
|
||||
|
||||
pub const fn y(&self) -> T
|
||||
where
|
||||
T: Copy,
|
||||
{
|
||||
self.y
|
||||
}
|
||||
|
||||
pub fn cast<T2>(&self) -> Point<T2>
|
||||
where
|
||||
T: num::cast::AsPrimitive<T2>,
|
||||
T2: Copy + 'static,
|
||||
{
|
||||
Point {
|
||||
x: self.x.as_(),
|
||||
y: self.y.as_(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: core::ops::Sub<T, Output = T> + Copy> core::ops::Sub<Point<T>> for Point<T> {
|
||||
type Output = Point<T>;
|
||||
fn sub(self, rhs: Point<T>) -> Self::Output {
|
||||
Point {
|
||||
x: self.x - rhs.x,
|
||||
y: self.y - rhs.y,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: core::ops::Add<T, Output = T> + Copy> core::ops::Add<Point<T>> for Point<T> {
|
||||
type Output = Point<T>;
|
||||
fn add(self, rhs: Point<T>) -> Self::Output {
|
||||
Point {
|
||||
x: self.x + rhs.x,
|
||||
y: self.y + rhs.y,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: core::ops::Sub<Output = T> + Copy> Point<T> {
|
||||
/// If both the boxes are in the same scale then make the translation of the origin to the
|
||||
/// other box
|
||||
pub fn with_origin(&self, origin: Self) -> Self {
|
||||
*self - origin
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: core::ops::Add<Output = T> + Copy> Point<T> {
|
||||
pub fn translate(&self, point: Point<T>) -> Self {
|
||||
*self + point
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: num::Zero> BBox<I>
|
||||
where
|
||||
I: num::cast::AsPrimitive<usize>,
|
||||
{
|
||||
pub fn zeros_ndarray_2d<T: num::Zero + Copy>(&self) -> ndarray::Array2<T> {
|
||||
ndarray::Array2::<T>::zeros((self.height.as_(), self.width.as_()))
|
||||
}
|
||||
pub fn zeros_ndarray_3d<T: num::Zero + Copy>(&self, channels: usize) -> ndarray::Array3<T> {
|
||||
ndarray::Array3::<T>::zeros((self.height.as_(), self.width.as_(), channels))
|
||||
}
|
||||
pub fn ones_ndarray_2d<T: num::One + Copy>(&self) -> ndarray::Array2<T> {
|
||||
ndarray::Array2::<T>::ones((self.height.as_(), self.width.as_()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: num::Float> BBox<T> {
|
||||
pub fn round(&self) -> Self {
|
||||
Self {
|
||||
x: self.x.round(),
|
||||
y: self.y.round(),
|
||||
width: self.width.round(),
|
||||
height: self.height.round(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod bbox_clamp_tests {
|
||||
use super::*;
|
||||
#[test]
|
||||
pub fn bbox_test_clamp_box() {
|
||||
let large_box = BBox::new(0, 0, 100, 100);
|
||||
let small_box = BBox::new(10, 10, 20, 20);
|
||||
let clamped = large_box.clamp_box(small_box);
|
||||
assert_eq!(clamped, small_box);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn bbox_test_clamp_box_offset() {
|
||||
let box_a = BBox::new(0, 0, 100, 100);
|
||||
let box_b = BBox::new(-10, -10, 20, 20);
|
||||
let clamped = box_b.clamp_box(box_a);
|
||||
let expected = BBox::new(0, 0, 10, 10);
|
||||
assert_eq!(expected, clamped);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod bbox_padding_tests {
|
||||
use super::*;
|
||||
#[test]
|
||||
pub fn bbox_test_padding() {
|
||||
let bbox = BBox::new(0, 0, 10, 10);
|
||||
let padded = bbox.padding(2);
|
||||
assert_eq!(padded, BBox::new(-2, -2, 14, 14));
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn bbox_test_padding_height() {
|
||||
let bbox = BBox::new(0, 0, 10, 10);
|
||||
let padded = bbox.padding_height(2);
|
||||
assert_eq!(padded, BBox::new(0, -2, 10, 14));
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn bbox_test_padding_width() {
|
||||
let bbox = BBox::new(0, 0, 10, 10);
|
||||
let padded = bbox.padding_width(2);
|
||||
assert_eq!(padded, BBox::new(-2, 0, 14, 10));
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn bbox_test_clamped_padding() {
|
||||
let bbox = BBox::new(0, 0, 10, 10);
|
||||
let padded = bbox.padding(2);
|
||||
let clamp = BBox::new(0, 0, 12, 12);
|
||||
let clamped = padded.clamp_box(clamp);
|
||||
assert_eq!(clamped, clamp);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn bbox_clamp_failure() {
|
||||
let og = BBox::new(475.0, 79.625, 37.0, 282.15);
|
||||
let padded = BBox {
|
||||
x: 471.3,
|
||||
y: 51.412499999999994,
|
||||
width: 40.69999999999999,
|
||||
height: 338.54999999999995,
|
||||
};
|
||||
let clamp = BBox::new(0.0, 0.0, 512.0, 512.0);
|
||||
let sus = padded.clamp_box(clamp);
|
||||
assert!(clamp.contains_bbox(sus));
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod bbox_scale_tests {
|
||||
use super::*;
|
||||
#[test]
|
||||
pub fn bbox_test_scale_int() {
|
||||
let bbox = BBox::new(0, 0, 10, 10);
|
||||
let scaled = bbox.scale(2);
|
||||
assert_eq!(scaled, BBox::new(-5, -5, 20, 20));
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn bbox_test_scale_float() {
|
||||
let bbox = BBox::new(0, 0, 10, 10).cast();
|
||||
let scaled = bbox.scale(1.05); // 5% increase
|
||||
let l = 10.0 * 0.05;
|
||||
assert_eq!(scaled, BBox::new(-l / 2.0, -l / 2.0, 10.0 + l, 10.0 + l));
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn bbox_test_scale_float_negative() {
|
||||
let bbox = BBox::new(0, 0, 10, 10).cast();
|
||||
let scaled = bbox.scale(0.95); // 5% decrease
|
||||
let l = -10.0 * 0.05;
|
||||
assert_eq!(scaled, BBox::new(-l / 2.0, -l / 2.0, 10.0 + l, 10.0 + l));
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn bbox_scale_float() {
|
||||
let bbox = BBox::new_xywh(0, 0, 200, 200);
|
||||
let scaled = bbox.cast::<f64>().scale(1.1).cast::<i32>().clamp(0, 1000);
|
||||
let expected = BBox::new(0, 0, 220, 220);
|
||||
assert_eq!(scaled, expected);
|
||||
}
|
||||
#[test]
|
||||
pub fn add_padding_bbox_example() {
|
||||
// let result = add_padding_bbox(
|
||||
// vec![Rect::new(100, 200, 300, 400)],
|
||||
// (0.1, 0.1),
|
||||
// (1000, 1000),
|
||||
// );
|
||||
// assert_eq!(result[0], Rect::new(70, 160, 360, 480));
|
||||
let bbox = BBox::new(100, 200, 300, 400);
|
||||
let scaled = bbox.cast::<f64>().scale(1.2).cast::<i32>().clamp(0, 1000);
|
||||
assert_eq!(bbox, BBox::new(100, 200, 300, 400));
|
||||
assert_eq!(scaled, BBox::new(70, 160, 360, 480));
|
||||
}
|
||||
#[test]
|
||||
pub fn scale_bboxes() {
|
||||
// let result = scale_bboxes(Rect::new(100, 200, 300, 400), (1000, 1000), (500, 500));
|
||||
// assert_eq!(result[0], Rect::new(200, 400, 600, 800));
|
||||
let bbox = BBox::new(100, 200, 300, 400);
|
||||
let scaled = bbox.scale(2);
|
||||
assert_eq!(scaled, BBox::new(200, 400, 600, 800));
|
||||
}
|
||||
}
|
||||
2
bbox/src/traits.rs
Normal file
2
bbox/src/traits.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod max;
|
||||
pub mod min;
|
||||
27
bbox/src/traits/max.rs
Normal file
27
bbox/src/traits/max.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
pub trait Max: Sized + Copy {
|
||||
fn max(self, other: Self) -> Self;
|
||||
}
|
||||
|
||||
macro_rules! impl_max {
|
||||
($($t:ty),*) => {
|
||||
$(
|
||||
impl Max for $t {
|
||||
fn max(self, other: Self) -> Self {
|
||||
Ord::max(self, other)
|
||||
}
|
||||
}
|
||||
)*
|
||||
};
|
||||
(float $($t:ty),*) => {
|
||||
$(
|
||||
impl Max for $t {
|
||||
fn max(self, other: Self) -> Self {
|
||||
Self::max(self, other)
|
||||
}
|
||||
}
|
||||
)*
|
||||
};
|
||||
}
|
||||
|
||||
impl_max!(usize, u8, u16, u32, u64, u128, isize, i8, i16, i32, i64, i128);
|
||||
impl_max!(float f32, f64);
|
||||
27
bbox/src/traits/min.rs
Normal file
27
bbox/src/traits/min.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
pub trait Min: Sized + Copy {
|
||||
fn min(self, other: Self) -> Self;
|
||||
}
|
||||
|
||||
macro_rules! impl_min {
|
||||
($($t:ty),*) => {
|
||||
$(
|
||||
impl Min for $t {
|
||||
fn min(self, other: Self) -> Self {
|
||||
Ord::min(self, other)
|
||||
}
|
||||
}
|
||||
)*
|
||||
};
|
||||
(float $($t:ty),*) => {
|
||||
$(
|
||||
impl Min for $t {
|
||||
fn min(self, other: Self) -> Self {
|
||||
Self::min(self, other)
|
||||
}
|
||||
}
|
||||
)*
|
||||
};
|
||||
}
|
||||
|
||||
impl_min!(usize, u8, u16, u32, u64, u128, isize, i8, i16, i32, i64, i128);
|
||||
impl_min!(float f32, f64);
|
||||
@@ -6,12 +6,16 @@ edition = "2024"
|
||||
[dependencies]
|
||||
color = "0.3.1"
|
||||
itertools = "0.14.0"
|
||||
nalgebra = "0.33.2"
|
||||
nalgebra = { workspace = true }
|
||||
ndarray = { version = "0.16.1", optional = true }
|
||||
num = "0.4.3"
|
||||
ordered-float = "5.0.0"
|
||||
simba = "0.9.0"
|
||||
thiserror = "2.0.12"
|
||||
tracing = { version = "0.1.41", optional = true, default-features = false }
|
||||
|
||||
[features]
|
||||
ndarray = ["dep:ndarray"]
|
||||
default = ["ndarray"]
|
||||
tracing = ["dep:tracing"]
|
||||
|
||||
default = ["ndarray", "tracing"]
|
||||
|
||||
@@ -4,11 +4,11 @@ pub use color::Rgba8;
|
||||
use ndarray::{Array1, Array3, ArrayViewMut3};
|
||||
|
||||
pub trait Draw<T> {
|
||||
fn draw(&mut self, item: T, color: color::Rgba8, thickness: usize);
|
||||
fn draw(&mut self, item: &T, color: color::Rgba8, thickness: usize);
|
||||
}
|
||||
|
||||
impl Draw<Aabb2<usize>> for Array3<u8> {
|
||||
fn draw(&mut self, item: Aabb2<usize>, color: color::Rgba8, thickness: usize) {
|
||||
fn draw(&mut self, item: &Aabb2<usize>, color: color::Rgba8, thickness: usize) {
|
||||
item.draw(self, color, thickness)
|
||||
}
|
||||
}
|
||||
@@ -65,8 +65,9 @@ impl Drawable<Array3<u8>> for Aabb2<usize> {
|
||||
pixel.assign(&color);
|
||||
})
|
||||
})
|
||||
.inspect_err(|e| {
|
||||
dbg!(e);
|
||||
.inspect_err(|_e| {
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::error!("{_e}")
|
||||
})
|
||||
.ok();
|
||||
});
|
||||
|
||||
@@ -2,9 +2,38 @@ pub mod draw;
|
||||
pub mod nms;
|
||||
pub mod roi;
|
||||
|
||||
use nalgebra::{Point, Point2, Point3, SVector, SimdPartialOrd, SimdValue};
|
||||
pub trait Num: num::Num + Copy + core::fmt::Debug + 'static {}
|
||||
impl<T: num::Num + Copy + core::fmt::Debug + 'static> Num for T {}
|
||||
use nalgebra::{Point, Point2, SVector, Vector2};
|
||||
pub trait Num:
|
||||
num::Num
|
||||
+ core::ops::AddAssign
|
||||
+ core::ops::SubAssign
|
||||
+ core::ops::MulAssign
|
||||
+ core::ops::DivAssign
|
||||
+ core::cmp::PartialOrd
|
||||
+ core::cmp::PartialEq
|
||||
+ nalgebra::SimdPartialOrd
|
||||
+ nalgebra::SimdValue
|
||||
+ Copy
|
||||
+ core::fmt::Debug
|
||||
+ 'static
|
||||
{
|
||||
}
|
||||
impl<
|
||||
T: num::Num
|
||||
+ core::ops::AddAssign
|
||||
+ core::ops::SubAssign
|
||||
+ core::ops::MulAssign
|
||||
+ core::ops::DivAssign
|
||||
+ core::cmp::PartialOrd
|
||||
+ core::cmp::PartialEq
|
||||
+ nalgebra::SimdPartialOrd
|
||||
+ nalgebra::SimdValue
|
||||
+ Copy
|
||||
+ core::fmt::Debug
|
||||
+ 'static,
|
||||
> Num for T
|
||||
{
|
||||
}
|
||||
|
||||
/// An axis aligned bounding box in `D` dimensions, defined by the minimum vertex and a size vector.
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
@@ -20,16 +49,27 @@ pub type Aabb2<T> = AxisAlignedBoundingBox<T, 2>;
|
||||
pub type Aabb3<T> = AxisAlignedBoundingBox<T, 3>;
|
||||
|
||||
impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
|
||||
pub fn new(point: Point<T, D>, size: SVector<T, D>) -> Self {
|
||||
// Panics if max < min
|
||||
pub fn new(min_point: Point<T, D>, max_point: Point<T, D>) -> Self {
|
||||
if max_point >= min_point {
|
||||
Self::from_min_max_vertices(min_point, max_point)
|
||||
} else {
|
||||
panic!("max_point must be greater than or equal to min_point");
|
||||
}
|
||||
}
|
||||
pub fn try_new(min_point: Point<T, D>, max_point: Point<T, D>) -> Option<Self> {
|
||||
if max_point < min_point {
|
||||
return None;
|
||||
}
|
||||
Some(Self::from_min_max_vertices(min_point, max_point))
|
||||
}
|
||||
pub fn new_point_size(point: Point<T, D>, size: SVector<T, D>) -> Self {
|
||||
Self { point, size }
|
||||
}
|
||||
|
||||
pub fn from_min_max_vertices(point1: Point<T, D>, point2: Point<T, D>) -> Self
|
||||
where
|
||||
T: core::ops::SubAssign,
|
||||
{
|
||||
let size = point2 - point1;
|
||||
Self::new(point1, SVector::from(size))
|
||||
pub fn from_min_max_vertices(min: Point<T, D>, max: Point<T, D>) -> Self {
|
||||
let size = max - min;
|
||||
Self::new_point_size(min, SVector::from(size))
|
||||
}
|
||||
|
||||
/// Only considers the points closest and furthest from origin
|
||||
@@ -123,6 +163,21 @@ impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn scale_uniform(self, scalar: T) -> Self
|
||||
where
|
||||
T: core::ops::MulAssign,
|
||||
T: core::ops::DivAssign,
|
||||
T: core::ops::SubAssign,
|
||||
{
|
||||
let two = T::one() + T::one();
|
||||
let new_size = self.size * scalar;
|
||||
let new_point = self.point.coords - (new_size - self.size) / two;
|
||||
Self {
|
||||
point: Point::from(new_point),
|
||||
size: new_size,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn contains_bbox(&self, other: &Self) -> bool
|
||||
where
|
||||
T: core::ops::AddAssign,
|
||||
@@ -151,7 +206,21 @@ impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
|
||||
self.intersection(other)
|
||||
}
|
||||
|
||||
pub fn union(&self, other: &Self) -> Self
|
||||
pub fn component_clamp(&self, min: T, max: T) -> Self
|
||||
where
|
||||
T: PartialOrd,
|
||||
{
|
||||
let mut this = *self;
|
||||
this.point.iter_mut().for_each(|x| {
|
||||
*x = nalgebra::clamp(*x, min, max);
|
||||
});
|
||||
this.size.iter_mut().for_each(|x| {
|
||||
*x = nalgebra::clamp(*x, min, max);
|
||||
});
|
||||
this
|
||||
}
|
||||
|
||||
pub fn merge(&self, other: &Self) -> Self
|
||||
where
|
||||
T: core::ops::AddAssign,
|
||||
T: core::ops::SubAssign,
|
||||
@@ -159,13 +228,24 @@ impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
|
||||
T: nalgebra::SimdValue,
|
||||
T: nalgebra::SimdPartialOrd,
|
||||
{
|
||||
let self_min = self.min_vertex();
|
||||
let self_max = self.max_vertex();
|
||||
let other_min = other.min_vertex();
|
||||
let other_max = other.max_vertex();
|
||||
let min = self_min.inf(&other_min);
|
||||
let max = self_max.sup(&other_max);
|
||||
Self::from_min_max_vertices(min, max)
|
||||
let min = self.min_vertex().inf(&other.min_vertex());
|
||||
let max = self.min_vertex().sup(&other.max_vertex());
|
||||
Self::new(min, max)
|
||||
}
|
||||
|
||||
pub fn union(&self, other: &Self) -> T
|
||||
where
|
||||
T: core::ops::AddAssign,
|
||||
T: core::ops::SubAssign,
|
||||
T: core::ops::MulAssign,
|
||||
T: PartialOrd,
|
||||
T: nalgebra::SimdValue,
|
||||
T: nalgebra::SimdPartialOrd,
|
||||
{
|
||||
self.measure() + other.measure()
|
||||
- Self::intersection(self, other)
|
||||
.map(|x| x.measure())
|
||||
.unwrap_or(T::zero())
|
||||
}
|
||||
|
||||
pub fn intersection(&self, other: &Self) -> Option<Self>
|
||||
@@ -176,21 +256,9 @@ impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
|
||||
T: nalgebra::SimdPartialOrd,
|
||||
T: nalgebra::SimdValue,
|
||||
{
|
||||
let self_min = self.min_vertex();
|
||||
let self_max = self.max_vertex();
|
||||
let other_min = other.min_vertex();
|
||||
let other_max = other.max_vertex();
|
||||
|
||||
if self_max < other_min || other_max < self_min {
|
||||
return None; // No intersection
|
||||
}
|
||||
|
||||
let min = self_min.sup(&other_min);
|
||||
let max = self_max.inf(&other_max);
|
||||
Some(Self::from_min_max_vertices(
|
||||
Point::from(min),
|
||||
Point::from(max),
|
||||
))
|
||||
let inter_min = self.min_vertex().sup(&other.min_vertex());
|
||||
let inter_max = self.max_vertex().inf(&other.max_vertex());
|
||||
Self::try_new(inter_min, inter_max)
|
||||
}
|
||||
|
||||
pub fn denormalize(&self, factor: nalgebra::SVector<T, D>) -> Self
|
||||
@@ -217,15 +285,17 @@ impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
|
||||
})
|
||||
}
|
||||
|
||||
// pub fn as_<T2>(&self) -> Option<Aabb<T2, D>>
|
||||
// where
|
||||
// T2: Num + simba::scalar::SubsetOf<T>,
|
||||
// {
|
||||
// Some(Aabb {
|
||||
// point: Point::from(self.point.coords.as_()),
|
||||
// size: self.size.as_(),
|
||||
// })
|
||||
// }
|
||||
pub fn as_<T2>(&self) -> Aabb<T2, D>
|
||||
where
|
||||
T2: Num,
|
||||
T: num::cast::AsPrimitive<T2>,
|
||||
{
|
||||
Aabb {
|
||||
point: Point::from(self.point.coords.map(|x| x.as_())),
|
||||
size: self.size.map(|x| x.as_()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn measure(&self) -> T
|
||||
where
|
||||
T: core::ops::MulAssign,
|
||||
@@ -233,7 +303,7 @@ impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
|
||||
self.size.product()
|
||||
}
|
||||
|
||||
pub fn iou(&self, other: &Self) -> Option<T>
|
||||
pub fn iou(&self, other: &Self) -> T
|
||||
where
|
||||
T: core::ops::AddAssign,
|
||||
T: core::ops::SubAssign,
|
||||
@@ -242,9 +312,19 @@ impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
|
||||
T: nalgebra::SimdValue,
|
||||
T: core::ops::MulAssign,
|
||||
{
|
||||
let intersection = self.intersection(other)?;
|
||||
let union = self.union(other);
|
||||
Some(intersection.measure() / union.measure())
|
||||
let lhs_min = self.min_vertex();
|
||||
let lhs_max = self.max_vertex();
|
||||
let rhs_min = other.min_vertex();
|
||||
let rhs_max = other.max_vertex();
|
||||
|
||||
let inter_min = lhs_min.sup(&rhs_min);
|
||||
let inter_max = lhs_max.inf(&rhs_max);
|
||||
if inter_max >= inter_min {
|
||||
let intersection = Aabb::new(inter_min, inter_max).measure();
|
||||
intersection / (self.measure() + other.measure() - intersection)
|
||||
} else {
|
||||
return T::zero();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -255,15 +335,15 @@ impl<T: Num> Aabb2<T> {
|
||||
{
|
||||
let point1 = Point2::new(x1, y1);
|
||||
let point2 = Point2::new(x2, y2);
|
||||
Self::from_min_max_vertices(point1, point2)
|
||||
Self::new(point1, point2)
|
||||
}
|
||||
pub fn new_2d(point1: Point2<T>, point2: Point2<T>) -> Self
|
||||
where
|
||||
T: core::ops::SubAssign,
|
||||
{
|
||||
let size = point2.coords - point1.coords;
|
||||
Self::new(point1, SVector::from(size))
|
||||
|
||||
pub fn from_xywh(x: T, y: T, w: T, h: T) -> Self {
|
||||
let point = Point2::new(x, y);
|
||||
let size = Vector2::new(w, h);
|
||||
Self::new_point_size(point, size)
|
||||
}
|
||||
|
||||
pub fn x1y1(&self) -> Point2<T> {
|
||||
self.point
|
||||
}
|
||||
@@ -327,14 +407,6 @@ impl<T: Num> Aabb2<T> {
|
||||
}
|
||||
|
||||
impl<T: Num> Aabb3<T> {
|
||||
pub fn new_3d(point1: Point3<T>, point2: Point3<T>) -> Self
|
||||
where
|
||||
T: core::ops::SubAssign,
|
||||
{
|
||||
let size = point2.coords - point1.coords;
|
||||
Self::new(point1, SVector::from(size))
|
||||
}
|
||||
|
||||
pub fn volume(&self) -> T
|
||||
where
|
||||
T: core::ops::MulAssign,
|
||||
@@ -343,124 +415,125 @@ impl<T: Num> Aabb3<T> {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bbox_new() {
|
||||
use nalgebra::{Point2, Vector2};
|
||||
#[cfg(test)]
|
||||
mod boudning_box_tests {
|
||||
use super::*;
|
||||
use nalgebra::*;
|
||||
|
||||
#[test]
|
||||
fn test_bbox_new() {
|
||||
let point1 = Point2::new(1.0, 2.0);
|
||||
let point2 = Point2::new(4.0, 6.0);
|
||||
let bbox = AxisAlignedBoundingBox::new_2d(point1, point2);
|
||||
let bbox = AxisAlignedBoundingBox::new(point1, point2);
|
||||
|
||||
assert_eq!(bbox.min_vertex(), point1);
|
||||
assert_eq!(bbox.size(), Vector2::new(3.0, 4.0));
|
||||
assert_eq!(bbox.center(), Point2::new(2.5, 4.0));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bounding_box_center_2d() {
|
||||
use nalgebra::{Point2, Vector2};
|
||||
#[test]
|
||||
fn test_intersection_and_merge() {
|
||||
let point1 = Point2::new(1, 5);
|
||||
let point2 = Point2::new(3, 2);
|
||||
let size1 = Vector2::new(3, 4);
|
||||
let size2 = Vector2::new(1, 3);
|
||||
|
||||
let this = Aabb2::new_point_size(point1, size1);
|
||||
let other = Aabb2::new_point_size(point2, size2);
|
||||
let inter = this.intersection(&other);
|
||||
let merged = this.merge(&other);
|
||||
assert_ne!(inter, Some(merged))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bounding_box_center_2d() {
|
||||
let point = Point2::new(1.0, 2.0);
|
||||
let size = Vector2::new(3.0, 4.0);
|
||||
let bbox = AxisAlignedBoundingBox::new(point, size);
|
||||
let bbox = AxisAlignedBoundingBox::new_point_size(point, size);
|
||||
|
||||
assert_eq!(bbox.min_vertex(), point);
|
||||
assert_eq!(bbox.size(), size);
|
||||
assert_eq!(bbox.center(), Point2::new(2.5, 4.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bounding_box_center_3d() {
|
||||
use nalgebra::{Point3, Vector3};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bounding_box_center_3d() {
|
||||
let point = Point3::new(1.0, 2.0, 3.0);
|
||||
let size = Vector3::new(4.0, 5.0, 6.0);
|
||||
let bbox = AxisAlignedBoundingBox::new(point, size);
|
||||
let bbox = AxisAlignedBoundingBox::new_point_size(point, size);
|
||||
|
||||
assert_eq!(bbox.min_vertex(), point);
|
||||
assert_eq!(bbox.size(), size);
|
||||
assert_eq!(bbox.center(), Point3::new(3.0, 4.5, 6.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bounding_box_padding_2d() {
|
||||
use nalgebra::{Point2, Vector2};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bounding_box_padding_2d() {
|
||||
let point = Point2::new(1.0, 2.0);
|
||||
let size = Vector2::new(3.0, 4.0);
|
||||
let bbox = AxisAlignedBoundingBox::new(point, size);
|
||||
let bbox = AxisAlignedBoundingBox::new_point_size(point, size);
|
||||
|
||||
let padded_bbox = bbox.padding(1.0);
|
||||
assert_eq!(padded_bbox.min_vertex(), Point2::new(0.5, 1.5));
|
||||
assert_eq!(padded_bbox.size(), Vector2::new(4.0, 5.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bounding_box_scaling_2d() {
|
||||
use nalgebra::{Point2, Vector2};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bounding_box_scaling_2d() {
|
||||
let point = Point2::new(1.0, 1.0);
|
||||
let size = Vector2::new(3.0, 4.0);
|
||||
let bbox = AxisAlignedBoundingBox::new(point, size);
|
||||
let bbox = AxisAlignedBoundingBox::new_point_size(point, size);
|
||||
|
||||
let padded_bbox = bbox.scale(Vector2::new(2.0, 2.0));
|
||||
assert_eq!(padded_bbox.min_vertex(), Point2::new(-2.0, -3.0));
|
||||
assert_eq!(padded_bbox.size(), Vector2::new(6.0, 8.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bounding_box_contains_2d() {
|
||||
use nalgebra::Point2;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bounding_box_contains_2d() {
|
||||
let point1 = Point2::new(1.0, 2.0);
|
||||
let point2 = Point2::new(4.0, 6.0);
|
||||
let bbox = AxisAlignedBoundingBox::new_2d(point1, point2);
|
||||
let bbox = AxisAlignedBoundingBox::new(point1, point2);
|
||||
|
||||
assert!(bbox.contains_point(&Point2::new(2.0, 3.0)));
|
||||
assert!(!bbox.contains_point(&Point2::new(5.0, 7.0)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bounding_box_union_2d() {
|
||||
use nalgebra::{Point2, Vector2};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bounding_box_union_2d() {
|
||||
let point1 = Point2::new(1.0, 2.0);
|
||||
let point2 = Point2::new(4.0, 6.0);
|
||||
let bbox1 = AxisAlignedBoundingBox::new_2d(point1, point2);
|
||||
let bbox1 = AxisAlignedBoundingBox::new(point1, point2);
|
||||
|
||||
let point3 = Point2::new(3.0, 5.0);
|
||||
let point4 = Point2::new(7.0, 8.0);
|
||||
let bbox2 = AxisAlignedBoundingBox::new_2d(point3, point4);
|
||||
let bbox2 = AxisAlignedBoundingBox::new(point3, point4);
|
||||
|
||||
let union_bbox = bbox1.union(&bbox2);
|
||||
let union_bbox = bbox1.merge(&bbox2);
|
||||
assert_eq!(union_bbox.min_vertex(), Point2::new(1.0, 2.0));
|
||||
assert_eq!(union_bbox.size(), Vector2::new(6.0, 6.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bounding_box_intersection_2d() {
|
||||
use nalgebra::{Point2, Vector2};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bounding_box_intersection_2d() {
|
||||
let point1 = Point2::new(1.0, 2.0);
|
||||
let point2 = Point2::new(4.0, 6.0);
|
||||
let bbox1 = AxisAlignedBoundingBox::new_2d(point1, point2);
|
||||
let bbox1 = AxisAlignedBoundingBox::new(point1, point2);
|
||||
|
||||
let point3 = Point2::new(3.0, 5.0);
|
||||
let point4 = Point2::new(5.0, 7.0);
|
||||
let bbox2 = AxisAlignedBoundingBox::new_2d(point3, point4);
|
||||
let bbox2 = AxisAlignedBoundingBox::new(point3, point4);
|
||||
|
||||
let intersection_bbox = bbox1.intersection(&bbox2).unwrap();
|
||||
assert_eq!(intersection_bbox.min_vertex(), Point2::new(3.0, 5.0));
|
||||
assert_eq!(intersection_bbox.size(), Vector2::new(1.0, 1.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bounding_box_contains_point() {
|
||||
use nalgebra::Point2;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bounding_box_contains_point() {
|
||||
let point1 = Point2::new(2, 3);
|
||||
let point2 = Point2::new(5, 4);
|
||||
let bbox = AxisAlignedBoundingBox::new_2d(point1, point2);
|
||||
let bbox = AxisAlignedBoundingBox::new(point1, point2);
|
||||
use itertools::Itertools;
|
||||
for (i, j) in (0..=10).cartesian_product(0..=10) {
|
||||
if bbox.contains_point(&Point2::new(i, j)) {
|
||||
@@ -479,10 +552,10 @@ fn test_bounding_box_contains_point() {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bounding_box_clamp_box_2d() {
|
||||
#[test]
|
||||
fn test_bounding_box_clamp_box_2d() {
|
||||
let bbox1 = Aabb2::from_x1y1x2y2(1, 1, 4, 4);
|
||||
let bbox2 = Aabb2::from_x1y1x2y2(2, 2, 3, 3);
|
||||
let clamped = bbox2.clamp(&bbox1).unwrap();
|
||||
@@ -495,4 +568,63 @@ fn test_bounding_box_clamp_box_2d() {
|
||||
let clamped = bbox1.clamp(&bbox2).unwrap();
|
||||
let expected = Aabb2::from_x1y1x2y2(5, 5, 7, 7);
|
||||
assert_eq!(clamped, expected)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_iou_identical_boxes() {
|
||||
let a = Aabb2::from_x1y1x2y2(1.0, 2.0, 4.0, 6.0);
|
||||
let b = Aabb2::from_x1y1x2y2(1.0, 2.0, 4.0, 6.0);
|
||||
assert_eq!(a.iou(&b), 1.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_iou_non_overlapping_boxes() {
|
||||
let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 1.0, 1.0);
|
||||
let b = Aabb2::from_x1y1x2y2(2.0, 2.0, 3.0, 3.0);
|
||||
assert_eq!(a.iou(&b), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_iou_partial_overlap() {
|
||||
let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 2.0, 2.0);
|
||||
let b = Aabb2::from_x1y1x2y2(1.0, 1.0, 3.0, 3.0);
|
||||
// Intersection area = 1, Union area = 7
|
||||
assert!((a.iou(&b) - 1.0 / 7.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_iou_one_inside_another() {
|
||||
let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 4.0, 4.0);
|
||||
let b = Aabb2::from_x1y1x2y2(1.0, 1.0, 3.0, 3.0);
|
||||
// Intersection area = 4, Union area = 16
|
||||
assert!((a.iou(&b) - 0.25).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_iou_edge_touching() {
|
||||
let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 1.0, 1.0);
|
||||
let b = Aabb2::from_x1y1x2y2(1.0, 0.0, 2.0, 1.0);
|
||||
assert_eq!(a.iou(&b), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_iou_corner_touching() {
|
||||
let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 1.0, 1.0);
|
||||
let b = Aabb2::from_x1y1x2y2(1.0, 1.0, 2.0, 2.0);
|
||||
assert_eq!(a.iou(&b), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_iou_zero_area_box() {
|
||||
let a = Aabb2::from_x1y1x2y2(0.0, 0.0, 0.0, 0.0);
|
||||
let b = Aabb2::from_x1y1x2y2(0.0, 0.0, 1.0, 1.0);
|
||||
assert_eq!(a.iou(&b), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_specific_values() {
|
||||
let box1 = Aabb2::from_xywh(0.69482, 0.6716774, 0.07493961, 0.14968264);
|
||||
let box2 = Aabb2::from_xywh(0.41546485, 0.70290875, 0.06197411, 0.08818436);
|
||||
assert!(box1.iou(&box2) >= 0.0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,11 @@
|
||||
use std::collections::HashSet;
|
||||
use std::collections::{HashSet, VecDeque};
|
||||
|
||||
use itertools::Itertools;
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
|
||||
pub enum NmsError {
|
||||
#[error("Boxes and scores length mismatch (boxes: {boxes}, scores: {scores})")]
|
||||
BoxesAndScoresLengthMismatch { boxes: usize, scores: usize },
|
||||
}
|
||||
|
||||
use crate::*;
|
||||
/// Apply Non-Maximum Suppression to a set of bounding boxes.
|
||||
@@ -18,10 +25,11 @@ pub fn nms<T>(
|
||||
scores: &[T],
|
||||
score_threshold: T,
|
||||
nms_threshold: T,
|
||||
) -> HashSet<usize>
|
||||
) -> Result<HashSet<usize>, NmsError>
|
||||
where
|
||||
T: Num
|
||||
+ num::Float
|
||||
+ ordered_float::FloatCore
|
||||
+ core::ops::Neg<Output = T>
|
||||
+ core::iter::Product<T>
|
||||
+ core::ops::AddAssign
|
||||
+ core::ops::SubAssign
|
||||
@@ -29,56 +37,37 @@ where
|
||||
+ nalgebra::SimdValue
|
||||
+ nalgebra::SimdPartialOrd,
|
||||
{
|
||||
use itertools::Itertools;
|
||||
|
||||
// Create vector of (index, box, score) tuples for boxes with scores above threshold
|
||||
let mut indexed_boxes: Vec<(usize, &Aabb2<T>, &T)> = boxes
|
||||
if boxes.len() != scores.len() {
|
||||
return Err(NmsError::BoxesAndScoresLengthMismatch {
|
||||
boxes: boxes.len(),
|
||||
scores: scores.len(),
|
||||
});
|
||||
}
|
||||
let mut combined: VecDeque<(usize, Aabb2<T>, T, bool)> = boxes
|
||||
.iter()
|
||||
.enumerate()
|
||||
.zip(scores.iter())
|
||||
.zip(scores)
|
||||
.filter_map(|((idx, bbox), score)| {
|
||||
if *score >= score_threshold {
|
||||
Some((idx, bbox, score))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
(*score > score_threshold).then_some((idx, *bbox, *score, true))
|
||||
})
|
||||
.sorted_by_cached_key(|(_, _, score, _)| -ordered_float::OrderedFloat(*score))
|
||||
.collect();
|
||||
|
||||
// Sort by score in descending order
|
||||
indexed_boxes.sort_by(|(_, _, score_a), (_, _, score_b)| {
|
||||
score_b
|
||||
.partial_cmp(score_a)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
let mut keep_indices = HashSet::new();
|
||||
let mut suppressed = HashSet::new();
|
||||
|
||||
for (i, (idx_i, bbox_i, _)) in indexed_boxes.iter().enumerate() {
|
||||
// Skip if this box is already suppressed
|
||||
if suppressed.contains(idx_i) {
|
||||
for i in 0..combined.len() {
|
||||
let first = combined[i];
|
||||
if !first.3 {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Keep this box
|
||||
keep_indices.insert(*idx_i);
|
||||
|
||||
// Compare with remaining boxes
|
||||
for (idx_j, bbox_j, _) in indexed_boxes.iter().skip(i + 1) {
|
||||
// Skip if this box is already suppressed
|
||||
if suppressed.contains(idx_j) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Calculate IoU and suppress if above threshold
|
||||
if let Some(iou) = bbox_i.iou(bbox_j) {
|
||||
if iou >= nms_threshold {
|
||||
suppressed.insert(*idx_j);
|
||||
}
|
||||
let bbox = first.1;
|
||||
for item in combined.iter_mut().skip(i + 1) {
|
||||
if bbox.iou(&item.1) > nms_threshold {
|
||||
item.3 = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
keep_indices
|
||||
Ok(combined
|
||||
.into_iter()
|
||||
.filter_map(|(idx, _, _, keep)| keep.then_some(idx))
|
||||
.collect())
|
||||
}
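// Editor's sketch (not part of the original diff): exercising the new fallible
// `nms` signature with `Aabb2<f64>` boxes and `f64` scores.
#[test]
fn sketch_nms_keeps_best_of_overlapping_boxes() {
    let boxes = vec![
        Aabb2::from_x1y1x2y2(0.0, 0.0, 2.0, 2.0),
        Aabb2::from_x1y1x2y2(0.1, 0.1, 2.1, 2.1), // heavy overlap with the first box
        Aabb2::from_x1y1x2y2(5.0, 5.0, 6.0, 6.0), // disjoint from both
    ];
    let scores = vec![0.9f64, 0.8, 0.7];
    // Keep detections scoring above 0.5 and suppress anything whose IoU with an
    // already-kept box exceeds 0.4.
    let keep = nms(&boxes, &scores, 0.5, 0.4).expect("boxes and scores have equal length");
    assert!(keep.contains(&0) && keep.contains(&2) && !keep.contains(&1));
}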
|
||||
|
||||
@@ -5,10 +5,17 @@ pub trait Roi<'a, Output> {
|
||||
type Error;
|
||||
fn roi(&'a self, aabb: Aabb2<usize>) -> Result<Output, Self::Error>;
|
||||
}
|
||||
|
||||
pub trait RoiMut<'a, Output> {
|
||||
type Error;
|
||||
fn roi_mut(&'a mut self, aabb: Aabb2<usize>) -> Result<Output, Self::Error>;
|
||||
}
|
||||
|
||||
pub trait MultiRoi<'a, Output> {
|
||||
type Error;
|
||||
fn multi_roi(&'a self, aabbs: &[Aabb2<usize>]) -> Result<Output, Self::Error>;
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug, Copy, Clone)]
|
||||
pub enum RoiError {
|
||||
#[error("Region of intereset is out of bounds")]
|
||||
@@ -36,7 +43,7 @@ impl<'a, T: Num> RoiMut<'a, ArrayViewMut3<'a, T>> for Array3<T> {
|
||||
let x2 = aabb.x2();
|
||||
let y1 = aabb.y1();
|
||||
let y2 = aabb.y2();
|
||||
if x1 >= x2 || y1 >= y2 || x2 > self.shape()[1] || y2 > self.shape()[0] {
|
||||
if x1 > x2 || y1 > y2 || x2 > self.shape()[1] || y2 > self.shape()[0] {
|
||||
return Err(RoiError::RoiOutOfBounds);
|
||||
}
|
||||
Ok(self.slice_mut(ndarray::s![y1..y2, x1..x2, ..]))
|
||||
@@ -95,3 +102,47 @@ pub fn reborrow_test() {
|
||||
};
|
||||
dbg!(y);
|
||||
}
|
||||
|
||||
impl<'a> MultiRoi<'a, Vec<ArrayView3<'a, u8>>> for Array3<u8> {
|
||||
type Error = RoiError;
|
||||
fn multi_roi(&'a self, aabbs: &[Aabb2<usize>]) -> Result<Vec<ArrayView3<'a, u8>>, Self::Error> {
|
||||
let (height, width, _channels) = self.dim();
|
||||
let outer_aabb = Aabb2::from_x1y1x2y2(0, 0, width, height);
|
||||
aabbs
|
||||
.iter()
|
||||
.map(|aabb| {
|
||||
let slice_arg =
|
||||
bbox_to_slice_arg(aabb.clamp(&outer_aabb).ok_or(RoiError::RoiOutOfBounds)?);
|
||||
Ok(self.slice(slice_arg))
|
||||
})
|
||||
.collect::<Result<Vec<_>, RoiError>>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b> MultiRoi<'a, Vec<ArrayView3<'b, u8>>> for ArrayView3<'b, u8> {
|
||||
type Error = RoiError;
|
||||
fn multi_roi(&'a self, aabbs: &[Aabb2<usize>]) -> Result<Vec<ArrayView3<'b, u8>>, Self::Error> {
|
||||
let (height, width, _channels) = self.dim();
|
||||
let outer_aabb = Aabb2::from_x1y1x2y2(0, 0, width, height);
|
||||
aabbs
|
||||
.iter()
|
||||
.map(|aabb| {
|
||||
let slice_arg =
|
||||
bbox_to_slice_arg(aabb.clamp(&outer_aabb).ok_or(RoiError::RoiOutOfBounds)?);
|
||||
Ok(self.slice_move(slice_arg))
|
||||
})
|
||||
.collect::<Result<Vec<_>, RoiError>>()
|
||||
}
|
||||
}
|
||||
|
||||
fn bbox_to_slice_arg(
|
||||
aabb: Aabb2<usize>,
|
||||
) -> ndarray::SliceInfo<[ndarray::SliceInfoElem; 3], ndarray::Ix3, ndarray::Ix3> {
|
||||
// Convert the bounding box into an ndarray slice argument over (rows y1..y2, cols x1..x2, all channels).
|
||||
let x1 = aabb.x1();
|
||||
let x2 = aabb.x2();
|
||||
let y1 = aabb.y1();
|
||||
let y2 = aabb.y2();
|
||||
ndarray::s![y1..y2, x1..x2, ..]
|
||||
}
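// Editor's sketch (not part of the original diff): cropping two regions out of an
// HWC u8 image with the new MultiRoi impl; boxes are clamped to the image bounds,
// assuming `clamp` returns the overlapping rectangle as in the tests above.
#[test]
fn sketch_multi_roi_views() {
    let image = ndarray::Array3::<u8>::zeros((100, 200, 3)); // (height, width, channels)
    let rois = [
        Aabb2::from_x1y1x2y2(10usize, 10, 50, 40),
        Aabb2::from_x1y1x2y2(0, 0, 500, 500), // clamped down to the 200x100 image
    ];
    let views = image.multi_roi(&rois).expect("both ROIs clamp to valid regions");
    assert_eq!(views[0].dim(), (30, 40, 3)); // rows = y2 - y1, cols = x2 - x1
    assert_eq!(views[1].dim(), (100, 200, 3));
}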
|
||||
|
||||
32
flake.lock
generated
@@ -3,11 +3,11 @@
|
||||
"advisory-db": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1750151065,
|
||||
"narHash": "sha256-il+CAqChFIB82xP6bO43dWlUVs+NlG7a4g8liIP5HcI=",
|
||||
"lastModified": 1755283329,
|
||||
"narHash": "sha256-33bd+PHbon+cgEiWE/zkr7dpEF5E0DiHOzyoUQbkYBc=",
|
||||
"owner": "rustsec",
|
||||
"repo": "advisory-db",
|
||||
"rev": "7573f55ba337263f61167dbb0ea926cdc7c8eb5d",
|
||||
"rev": "61aac2116c8cb7cc80ff8ca283eec7687d384038",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -18,11 +18,11 @@
|
||||
},
|
||||
"crane": {
|
||||
"locked": {
|
||||
"lastModified": 1750266157,
|
||||
"narHash": "sha256-tL42YoNg9y30u7zAqtoGDNdTyXTi8EALDeCB13FtbQA=",
|
||||
"lastModified": 1754269165,
|
||||
"narHash": "sha256-0tcS8FHd4QjbCVoxN9jI+PjHgA4vc/IjkUSp+N3zy0U=",
|
||||
"owner": "ipetkov",
|
||||
"repo": "crane",
|
||||
"rev": "e37c943371b73ed87faf33f7583860f81f1d5a48",
|
||||
"rev": "444e81206df3f7d92780680e45858e31d2f07a08",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -109,16 +109,16 @@
|
||||
"mnn-src": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1749173738,
|
||||
"narHash": "sha256-pNljvQ4xMZ4VmuxQyXt+boNBZD0+UZNpNLrWrj8Rtfw=",
|
||||
"lastModified": 1753256753,
|
||||
"narHash": "sha256-aTpwVZBkpQiwOVVXDfQIVEx9CswNiPbvNftw8KsoW+Q=",
|
||||
"owner": "alibaba",
|
||||
"repo": "MNN",
|
||||
"rev": "ebdada82634300956e08bd4056ecfeb1e4f23b32",
|
||||
"rev": "a739ea5870a4a45680f0e36ba9662ca39f2f4eec",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "alibaba",
|
||||
"ref": "3.2.0",
|
||||
"ref": "3.2.2",
|
||||
"repo": "MNN",
|
||||
"type": "github"
|
||||
}
|
||||
@@ -145,11 +145,11 @@
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1750506804,
|
||||
"narHash": "sha256-VLFNc4egNjovYVxDGyBYTrvVCgDYgENp5bVi9fPTDYc=",
|
||||
"lastModified": 1755186698,
|
||||
"narHash": "sha256-wNO3+Ks2jZJ4nTHMuks+cxAiVBGNuEBXsT29Bz6HASo=",
|
||||
"owner": "nixos",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "4206c4cb56751df534751b058295ea61357bbbaa",
|
||||
"rev": "fbcf476f790d8a217c3eab4e12033dc4a0f6d23c",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -178,11 +178,11 @@
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1750732748,
|
||||
"narHash": "sha256-HR2b3RHsPeJm+Fb+1ui8nXibgniVj7hBNvUbXEyz0DU=",
|
||||
"lastModified": 1755485198,
|
||||
"narHash": "sha256-C3042ST2lUg0nh734gmuP4lRRIBitA6Maegg2/jYRM4=",
|
||||
"owner": "oxalica",
|
||||
"repo": "rust-overlay",
|
||||
"rev": "4b4494b2ba7e8a8041b2e28320b2ee02c115c75f",
|
||||
"rev": "aa45e63d431b28802ca4490cfc796b9e31731df7",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
||||
129
flake.nix
@@ -22,7 +22,7 @@
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
mnn-src = {
|
||||
url = "github:alibaba/MNN/3.2.0";
|
||||
url = "github:alibaba/MNN/3.2.2";
|
||||
flake = false;
|
||||
};
|
||||
};
|
||||
@@ -43,13 +43,15 @@
|
||||
system: let
|
||||
pkgs = import nixpkgs {
|
||||
inherit system;
|
||||
config.allowUnfree = true;
|
||||
config.cudaSupport = pkgs.stdenv.isLinux;
|
||||
overlays = [
|
||||
rust-overlay.overlays.default
|
||||
(final: prev: {
|
||||
mnn = mnn-overlay.packages.${system}.mnn.override {
|
||||
src = mnn-src;
|
||||
buildConverter = true;
|
||||
enableMetal = true;
|
||||
enableMetal = pkgs.stdenv.isDarwin;
|
||||
enableOpencl = true;
|
||||
};
|
||||
})
|
||||
@@ -61,17 +63,44 @@
|
||||
|
||||
stableToolchain = pkgs.rust-bin.stable.latest.default;
|
||||
stableToolchainWithLLvmTools = stableToolchain.override {
|
||||
extensions = ["rust-src" "llvm-tools"];
|
||||
extensions = [
|
||||
"rust-src"
|
||||
"llvm-tools"
|
||||
];
|
||||
};
|
||||
stableToolchainWithRustAnalyzer = stableToolchain.override {
|
||||
extensions = ["rust-src" "rust-analyzer"];
|
||||
extensions = [
|
||||
"rust-src"
|
||||
"rust-analyzer"
|
||||
];
|
||||
};
|
||||
craneLib = (crane.mkLib pkgs).overrideToolchain stableToolchain;
|
||||
craneLibLLvmTools = (crane.mkLib pkgs).overrideToolchain stableToolchainWithLLvmTools;
|
||||
|
||||
ort_static = (pkgs.onnxruntime.override {cudaSupport = true;}).overrideAttrs (old: {
|
||||
cmakeFlags =
|
||||
old.cmakeFlags
|
||||
++ [
|
||||
"-Donnxruntime_BUILD_SHARED_LIB=OFF"
|
||||
"-Donnxruntime_BUILD_STATIC_LIB=ON"
|
||||
];
|
||||
});
|
||||
patchedOnnxruntime = pkgs.onnxruntime.overrideAttrs (old: {
|
||||
patches = [./patches/ort_env_global_mutex.patch];
|
||||
});
|
||||
src = let
|
||||
filterBySuffix = path: exts: lib.any (ext: lib.hasSuffix ext path) exts;
|
||||
sourceFilters = path: type: (craneLib.filterCargoSources path type) || filterBySuffix path [".c" ".h" ".hpp" ".cpp" ".cc"];
|
||||
sourceFilters = path: type:
|
||||
(craneLib.filterCargoSources path type)
|
||||
|| filterBySuffix path [
|
||||
".c"
|
||||
".h"
|
||||
".hpp"
|
||||
".cpp"
|
||||
".cc"
|
||||
".mnn"
|
||||
".onnx"
|
||||
];
|
||||
in
|
||||
lib.cleanSourceWith {
|
||||
filter = sourceFilters;
|
||||
@@ -81,15 +110,21 @@
|
||||
{
|
||||
inherit src;
|
||||
pname = name;
|
||||
stdenv = pkgs.clangStdenv;
|
||||
stdenv = p: p.clangStdenv;
|
||||
doCheck = false;
|
||||
LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib";
|
||||
# nativeBuildInputs = with pkgs; [
|
||||
# cmake
|
||||
# llvmPackages.libclang.lib
|
||||
# ];
|
||||
# ORT_LIB_LOCATION = "${patchedOnnxruntime}";
|
||||
# ORT_ENV_SYSTEM_LIB_LOCATION = "${patchedOnnxruntime}/lib";
|
||||
# ORT_ENV_PREFER_DYNAMIC_LINK = true;
|
||||
nativeBuildInputs = with pkgs; [
|
||||
cmake
|
||||
pkg-config
|
||||
];
|
||||
buildInputs = with pkgs;
|
||||
[]
|
||||
[
|
||||
patchedOnnxruntime
|
||||
sqlite
|
||||
]
|
||||
++ (lib.optionals pkgs.stdenv.isDarwin [
|
||||
libiconv
|
||||
apple-sdk_13
|
||||
@@ -102,11 +137,13 @@
|
||||
in {
|
||||
checks =
|
||||
{
|
||||
"${name}-clippy" = craneLib.cargoClippy (commonArgs
|
||||
"${name}-clippy" = craneLib.cargoClippy (
|
||||
commonArgs
|
||||
// {
|
||||
inherit cargoArtifacts;
|
||||
cargoClippyExtraArgs = "--all-targets -- --deny warnings";
|
||||
});
|
||||
}
|
||||
);
|
||||
"${name}-docs" = craneLib.cargoDoc (commonArgs // {inherit cargoArtifacts;});
|
||||
"${name}-fmt" = craneLib.cargoFmt {inherit src;};
|
||||
"${name}-toml-fmt" = craneLib.taploFmt {
|
||||
@@ -121,22 +158,29 @@
|
||||
"${name}-deny" = craneLib.cargoDeny {
|
||||
inherit src;
|
||||
};
|
||||
"${name}-nextest" = craneLib.cargoNextest (commonArgs
|
||||
"${name}-nextest" = craneLib.cargoNextest (
|
||||
commonArgs
|
||||
// {
|
||||
inherit cargoArtifacts;
|
||||
partitions = 1;
|
||||
partitionType = "count";
|
||||
});
|
||||
}
|
||||
);
|
||||
}
|
||||
// lib.optionalAttrs (!pkgs.stdenv.isDarwin) {
|
||||
"${name}-llvm-cov" = craneLibLLvmTools.cargoLlvmCov (commonArgs // {inherit cargoArtifacts;});
|
||||
};
|
||||
|
||||
packages = let
|
||||
pkg = craneLib.buildPackage (commonArgs
|
||||
// {inherit cargoArtifacts;}
|
||||
pkg = craneLib.buildPackage (
|
||||
commonArgs
|
||||
// {
|
||||
nativeBuildInputs = with pkgs; [
|
||||
inherit cargoArtifacts;
|
||||
}
|
||||
// {
|
||||
nativeBuildInputs = with pkgs;
|
||||
commonArgs.nativeBuildInputs
|
||||
++ [
|
||||
installShellFiles
|
||||
];
|
||||
postInstall = ''
|
||||
@@ -145,28 +189,69 @@
|
||||
--fish <($out/bin/${name} completions fish) \
|
||||
--zsh <($out/bin/${name} completions zsh)
|
||||
'';
|
||||
});
|
||||
}
|
||||
);
|
||||
in {
|
||||
"${name}" = pkg;
|
||||
default = pkg;
|
||||
onnxruntime = ort_static;
|
||||
};
|
||||
|
||||
devShells = {
|
||||
default = pkgs.mkShell.override {stdenv = pkgs.clangStdenv;} (commonArgs
|
||||
// {
|
||||
default = pkgs.mkShell.override {stdenv = pkgs.clangStdenv;} (
|
||||
commonArgs
|
||||
// rec {
|
||||
LLDB_DEBUGSERVER_PATH = "/Applications/Xcode.app/Contents/SharedFrameworks/LLDB.framework/Versions/A/Resources/debugserver";
|
||||
LD_LIBRARY_PATH = "$LD_LIBRARY_PATH:${builtins.toString (pkgs.lib.makeLibraryPath packages)}";
|
||||
packages = with pkgs;
|
||||
[
|
||||
stableToolchainWithRustAnalyzer
|
||||
cargo-expand
|
||||
cargo-outdated
|
||||
cargo-nextest
|
||||
cargo-deny
|
||||
cmake
|
||||
mnn
|
||||
cargo-make
|
||||
hyperfine
|
||||
opencv
|
||||
uv
|
||||
# (python312.withPackages (ps:
|
||||
# with ps; [
|
||||
# numpy
|
||||
# matplotlib
|
||||
# scikit-learn
|
||||
# opencv-python
|
||||
# seaborn
|
||||
# torch
|
||||
# torchvision
|
||||
# tensorflow-lite
|
||||
# retinaface
|
||||
# facenet-pytorch
|
||||
# tqdm
|
||||
# pillow
|
||||
# orjson
|
||||
# huggingface-hub
|
||||
# # insightface
|
||||
# ]))
|
||||
]
|
||||
++ (lib.optionals pkgs.stdenv.isDarwin [
|
||||
apple-sdk_13
|
||||
])
|
||||
++ (lib.optionals pkgs.stdenv.isLinux [
|
||||
xorg.libX11
|
||||
xorg.libXcursor
|
||||
xorg.libXrandr
|
||||
xorg.libXi
|
||||
xorg.libxcb
|
||||
libxkbcommon
|
||||
vulkan-loader
|
||||
wayland
|
||||
zenity
|
||||
cudatoolkit
|
||||
]);
|
||||
});
|
||||
}
|
||||
);
|
||||
};
|
||||
}
|
||||
)
|
||||
|
||||
15
justfile
@@ -1,2 +1,13 @@
|
||||
run:
|
||||
cargo run -r detect -- ./1000066593.jpg
|
||||
run_onnx ep = "cpu" arg = "selfie.jpg":
|
||||
cargo run -r detect -p {{ep}} -t 0.3 -o detected.jpg -- {{arg}}
|
||||
run_mnn forward = "cpu" arg = "selfie.jpg":
|
||||
cargo run -r detect -f {{forward}} -o detected.jpg -- {{arg}}
|
||||
|
||||
open:
|
||||
open detected.jpg
|
||||
|
||||
bench:
|
||||
cargo build --release
|
||||
BINARY="" hyperfine --warmup 3 --export-markdown benchmark.md \
|
||||
"$CARGO_TARGET_DIR/release/detector detect -f cpu selfie.jpg" \
|
||||
"$CARGO_TARGET_DIR/release/detector detect -f cpu -b 1 selfie.jpg"
|
||||
|
||||
Binary file not shown.
@@ -5,7 +5,7 @@ fn shape_error() -> ndarray::ShapeError {
|
||||
|
||||
mod rgb8 {
|
||||
use super::Result;
|
||||
pub(super) fn image_as_ndarray(image: &image::RgbImage) -> Result<ndarray::ArrayView3<u8>> {
|
||||
pub(super) fn image_as_ndarray(image: &image::RgbImage) -> Result<ndarray::ArrayView3<'_, u8>> {
|
||||
let (width, height) = image.dimensions();
|
||||
let data = image.as_raw();
|
||||
ndarray::ArrayView3::from_shape((height as usize, width as usize, 3), data)
|
||||
@@ -31,7 +31,9 @@ mod rgb8 {
|
||||
|
||||
mod rgba8 {
|
||||
use super::Result;
|
||||
pub(super) fn image_as_ndarray(image: &image::RgbaImage) -> Result<ndarray::ArrayView3<u8>> {
|
||||
pub(super) fn image_as_ndarray(
|
||||
image: &image::RgbaImage,
|
||||
) -> Result<ndarray::ArrayView3<'_, u8>> {
|
||||
let (width, height) = image.dimensions();
|
||||
let data = image.as_raw();
|
||||
ndarray::ArrayView3::from_shape((height as usize, width as usize, 4), data)
|
||||
@@ -57,7 +59,9 @@ mod rgba8 {
|
||||
|
||||
mod gray8 {
|
||||
use super::Result;
|
||||
pub(super) fn image_as_ndarray(image: &image::GrayImage) -> Result<ndarray::ArrayView2<u8>> {
|
||||
pub(super) fn image_as_ndarray(
|
||||
image: &image::GrayImage,
|
||||
) -> Result<ndarray::ArrayView2<'_, u8>> {
|
||||
let (width, height) = image.dimensions();
|
||||
let data = image.as_raw();
|
||||
ndarray::ArrayView2::from_shape((height as usize, width as usize), data)
|
||||
@@ -82,7 +86,7 @@ mod gray_alpha8 {
|
||||
use super::Result;
|
||||
pub(super) fn image_as_ndarray(
|
||||
image: &image::GrayAlphaImage,
|
||||
) -> Result<ndarray::ArrayView3<u8>> {
|
||||
) -> Result<ndarray::ArrayView3<'_, u8>> {
|
||||
let (width, height) = image.dimensions();
|
||||
let data = image.as_raw();
|
||||
ndarray::ArrayView3::from_shape((height as usize, width as usize, 2), data)
|
||||
@@ -110,7 +114,7 @@ mod gray_alpha8 {
|
||||
|
||||
mod dynamic_image {
|
||||
use super::*;
|
||||
pub fn image_as_ndarray(image: &image::DynamicImage) -> Result<ndarray::ArrayViewD<u8>> {
|
||||
pub fn image_as_ndarray(image: &image::DynamicImage) -> Result<ndarray::ArrayViewD<'_, u8>> {
|
||||
Ok(match image {
|
||||
image::DynamicImage::ImageRgb8(img) => rgb8::image_as_ndarray(img)?.into_dyn(),
|
||||
image::DynamicImage::ImageRgba8(img) => rgba8::image_as_ndarray(img)?.into_dyn(),
|
||||
|
||||
@@ -147,7 +147,7 @@ impl<S: ndarray::Data<Elem = T>, T: seal::Sealed + bytemuck::Pod, D: ndarray::Di
|
||||
NdAsImage<T, D> for ndarray::ArrayBase<S, D>
|
||||
{
|
||||
/// Clones self and makes a new image
|
||||
fn as_image_ref(&self) -> Result<ImageRef> {
|
||||
fn as_image_ref(&self) -> Result<ImageRef<'_>> {
|
||||
let shape = self.shape();
|
||||
let rows = *shape
|
||||
.first()
|
||||
|
||||
11
ndarray-safetensors/Cargo.toml
Normal file
@@ -0,0 +1,11 @@
|
||||
[package]
|
||||
name = "ndarray-safetensors"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
bytemuck = { version = "1.23.2" }
|
||||
half = { version = "2.6.0", default-features = false, features = ["bytemuck"] }
|
||||
ndarray = { version = "0.16.1", default-features = false, features = ["std"] }
|
||||
safetensors = "0.6.2"
|
||||
thiserror = "2.0.15"
|
||||
449
ndarray-safetensors/src/lib.rs
Normal file
@@ -0,0 +1,449 @@
|
||||
//! # ndarray-safetensors
|
||||
//!
|
||||
//! A Rust library for serializing and deserializing `ndarray` arrays using the SafeTensors format.
|
||||
//!
|
||||
//! ## Features
|
||||
//! - Serialize `ndarray::ArrayView` to SafeTensors format
|
||||
//! - Deserialize SafeTensors data back to `ndarray::ArrayView`
|
||||
//! - Support for multiple data types (f32, f64, i8-i64, u8-u64, f16, bf16)
|
||||
//! - Zero-copy deserialization when possible
|
||||
//! - Metadata support
|
||||
//!
|
||||
//! ## Example
|
||||
//! ```rust
|
||||
//! use ndarray::Array2;
|
||||
//! use ndarray_safetensors::{SafeArrays, SafeArraysView};
|
||||
//!
|
||||
//! // Create some data
|
||||
//! let array = Array2::<f32>::zeros((3, 4));
|
||||
//!
|
||||
//! // Serialize
|
||||
//! let mut safe_arrays = SafeArrays::new();
|
||||
//! safe_arrays.insert_ndarray("my_tensor", array.view()).unwrap();
|
||||
//! safe_arrays.insert_metadata("author", "example");
|
||||
//! let bytes = safe_arrays.serialize().unwrap();
|
||||
//!
|
||||
//! // Deserialize
|
||||
//! let view = SafeArraysView::from_bytes(&bytes).unwrap();
|
||||
//! let tensor: ndarray::ArrayView2<f32> = view.tensor("my_tensor").unwrap();
|
||||
//! assert_eq!(tensor.shape(), &[3, 4]);
|
||||
//! ```
|
||||
|
||||
use safetensors::View;
|
||||
use std::borrow::Cow;
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
|
||||
use thiserror::Error;
|
||||
/// Errors that can occur during SafeTensor operations
|
||||
#[derive(Error, Debug)]
|
||||
pub enum SafeTensorError {
|
||||
#[error("Tensor not found: {0}")]
|
||||
TensorNotFound(String),
|
||||
#[error("Invalid tensor data: Got {0} Expected: {1}")]
|
||||
InvalidTensorData(&'static str, String),
|
||||
#[error("IO error: {0}")]
|
||||
IoError(#[from] std::io::Error),
|
||||
#[error("Safetensor error: {0}")]
|
||||
SafeTensor(#[from] safetensors::SafeTensorError),
|
||||
#[error("ndarray::ShapeError error: {0}")]
|
||||
NdarrayShapeError(#[from] ndarray::ShapeError),
|
||||
}
|
||||
|
||||
type Result<T, E = SafeTensorError> = core::result::Result<T, E>;
|
||||
|
||||
use safetensors::tensor::SafeTensors;
|
||||
|
||||
/// A view into SafeTensors data that provides access to ndarray tensors
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust
|
||||
/// use ndarray::Array2;
|
||||
/// use ndarray_safetensors::{SafeArrays, SafeArraysView};
|
||||
///
|
||||
/// let array = Array2::<f32>::ones((2, 3));
|
||||
/// let mut safe_arrays = SafeArrays::new();
|
||||
/// safe_arrays.insert_ndarray("data", array.view()).unwrap();
|
||||
/// let bytes = safe_arrays.serialize().unwrap();
|
||||
///
|
||||
/// let view = SafeArraysView::from_bytes(&bytes).unwrap();
|
||||
/// let tensor: ndarray::ArrayView2<f32> = view.tensor("data").unwrap();
|
||||
/// ```
|
||||
#[derive(Debug)]
|
||||
pub struct SafeArraysView<'a> {
|
||||
pub tensors: SafeTensors<'a>,
|
||||
}
|
||||
|
||||
impl<'a> SafeArraysView<'a> {
|
||||
fn new(tensors: SafeTensors<'a>) -> Self {
|
||||
Self { tensors }
|
||||
}
|
||||
|
||||
/// Create a SafeArraysView from serialized bytes
|
||||
pub fn from_bytes(bytes: &'a [u8]) -> Result<SafeArraysView<'a>> {
|
||||
let tensors = SafeTensors::deserialize(bytes)?;
|
||||
Ok(Self::new(tensors))
|
||||
}
|
||||
|
||||
/// Get a dynamic-dimensional tensor by name
|
||||
pub fn dynamic_tensor<T: STDtype>(&self, name: &str) -> Result<ndarray::ArrayViewD<'a, T>> {
|
||||
self.tensors
|
||||
.tensor(name)
|
||||
.map(|tensor| tensor_view_to_array_view(tensor))?
|
||||
}
|
||||
|
||||
/// Get a tensor with specific dimensions by name
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust
|
||||
/// # use ndarray::Array2;
|
||||
/// # use ndarray_safetensors::{SafeArrays, SafeArraysView};
|
||||
/// # let array = Array2::<f32>::ones((2, 3));
|
||||
/// # let mut safe_arrays = SafeArrays::new();
|
||||
/// # safe_arrays.insert_ndarray("data", array.view()).unwrap();
|
||||
/// # let bytes = safe_arrays.serialize().unwrap();
|
||||
/// # let view = SafeArraysView::from_bytes(&bytes).unwrap();
|
||||
/// let tensor: ndarray::ArrayView2<f32> = view.tensor("data").unwrap();
|
||||
/// ```
|
||||
pub fn tensor<T: STDtype, Dim: ndarray::Dimension>(
|
||||
&self,
|
||||
name: &str,
|
||||
) -> Result<ndarray::ArrayView<'a, T, Dim>> {
|
||||
Ok(self
|
||||
.tensors
|
||||
.tensor(name)
|
||||
.map(|tensor| tensor_view_to_array_view(tensor))?
|
||||
.map(|array_view| array_view.into_dimensionality::<Dim>())??)
|
||||
}
|
||||
|
||||
pub fn tensor_by_index<T: STDtype, Dim: ndarray::Dimension>(
|
||||
&self,
|
||||
index: usize,
|
||||
) -> Result<ndarray::ArrayView<'a, T, Dim>> {
|
||||
self.tensors
|
||||
.iter()
|
||||
.nth(index)
|
||||
.ok_or(SafeTensorError::TensorNotFound(format!(
|
||||
"Index {} out of bounds",
|
||||
index
|
||||
)))
|
||||
.map(|(_, tensor)| tensor_view_to_array_view(tensor))?
|
||||
.map(|array_view| array_view.into_dimensionality::<Dim>())?
|
||||
.map_err(SafeTensorError::NdarrayShapeError)
|
||||
}
|
||||
|
||||
/// Get an iterator over tensor names
|
||||
pub fn names(&self) -> std::vec::IntoIter<&str> {
|
||||
self.tensors.names().into_iter()
|
||||
}
|
||||
|
||||
/// Get the number of tensors
|
||||
pub fn len(&self) -> usize {
|
||||
self.tensors.len()
|
||||
}
|
||||
|
||||
/// Check if there are no tensors
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.tensors.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait for types that can be stored in SafeTensors
|
||||
///
|
||||
/// Implemented for: f32, f64, i8, i16, i32, i64, u8, u16, u32, u64, f16, bf16
|
||||
pub trait STDtype: bytemuck::Pod {
|
||||
fn dtype() -> safetensors::tensor::Dtype;
|
||||
fn size() -> usize {
|
||||
(Self::dtype().bitsize() / 8).max(1)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_dtype {
|
||||
($($t:ty => $dtype:expr),* $(,)?) => {
|
||||
$(
|
||||
impl STDtype for $t {
|
||||
fn dtype() -> safetensors::tensor::Dtype {
|
||||
$dtype
|
||||
}
|
||||
}
|
||||
)*
|
||||
};
|
||||
}
|
||||
|
||||
use safetensors::tensor::Dtype;
|
||||
|
||||
impl_dtype!(
|
||||
// bool => Dtype::BOOL, // idk if ndarray::ArrayD<bool> is packed
|
||||
f32 => Dtype::F32,
|
||||
f64 => Dtype::F64,
|
||||
i8 => Dtype::I8,
|
||||
i16 => Dtype::I16,
|
||||
i32 => Dtype::I32,
|
||||
i64 => Dtype::I64,
|
||||
u8 => Dtype::U8,
|
||||
u16 => Dtype::U16,
|
||||
u32 => Dtype::U32,
|
||||
u64 => Dtype::U64,
|
||||
half::f16 => Dtype::F16,
|
||||
half::bf16 => Dtype::BF16,
|
||||
);
|
||||
|
||||
fn tensor_view_to_array_view<'a, T: STDtype>(
|
||||
tensor: safetensors::tensor::TensorView<'a>,
|
||||
) -> Result<ndarray::ArrayViewD<'a, T>> {
|
||||
let shape = tensor.shape();
|
||||
let dtype = tensor.dtype();
|
||||
if T::dtype() != dtype {
|
||||
return Err(SafeTensorError::InvalidTensorData(
|
||||
core::any::type_name::<T>(),
|
||||
dtype.to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let data = tensor.data();
|
||||
let data: &[T] = bytemuck::cast_slice(data);
|
||||
let array = ndarray::ArrayViewD::from_shape(shape, data)?;
|
||||
Ok(array)
|
||||
}
|
||||
|
||||
/// Builder for creating SafeTensors data from ndarray tensors
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust
|
||||
/// use ndarray::{Array1, Array2};
|
||||
/// use ndarray_safetensors::SafeArrays;
|
||||
///
|
||||
/// let mut safe_arrays = SafeArrays::new();
|
||||
///
|
||||
/// let array1 = Array1::<f32>::from_vec(vec![1.0, 2.0, 3.0]);
|
||||
/// let array2 = Array2::<i32>::zeros((2, 2));
|
||||
///
|
||||
/// safe_arrays.insert_ndarray("vector", array1.view()).unwrap();
|
||||
/// safe_arrays.insert_ndarray("matrix", array2.view()).unwrap();
|
||||
/// safe_arrays.insert_metadata("version", "1.0");
|
||||
///
|
||||
/// let bytes = safe_arrays.serialize().unwrap();
|
||||
/// ```
|
||||
#[derive(Debug, Clone, Default)]
|
||||
#[non_exhaustive]
|
||||
pub struct SafeArrays<'a> {
|
||||
pub tensors: BTreeMap<String, SafeArray<'a>>,
|
||||
pub metadata: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
impl<'a, K: AsRef<str>> FromIterator<(K, SafeArray<'a>)> for SafeArrays<'a> {
|
||||
fn from_iter<T: IntoIterator<Item = (K, SafeArray<'a>)>>(iter: T) -> Self {
|
||||
let tensors = iter
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k.as_ref().to_owned(), v))
|
||||
.collect();
|
||||
Self {
|
||||
tensors,
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, K: AsRef<str>, T: IntoIterator<Item = (K, SafeArray<'a>)>> From<T> for SafeArrays<'a> {
|
||||
fn from(iter: T) -> Self {
|
||||
let tensors = iter
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k.as_ref().to_owned(), v))
|
||||
.collect();
|
||||
Self {
|
||||
tensors,
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> SafeArrays<'a> {
|
||||
/// Create a SafeArrays from an iterator of (name, ndarray::ArrayView) pairs
|
||||
/// ```rust
|
||||
/// use ndarray::{Array2, Array3};
|
||||
/// use ndarray_safetensors::{SafeArrays, SafeArray};
|
||||
/// let array = Array2::<f32>::zeros((3, 4));
|
||||
/// let safe_arrays = SafeArrays::from_ndarrays(vec![
|
||||
/// ("test_tensor", array.view()),
|
||||
/// ("test_tensor2", array.view()),
|
||||
/// ]).unwrap();
|
||||
/// ```
|
||||
|
||||
pub fn from_ndarrays<
|
||||
K: AsRef<str>,
|
||||
T: STDtype,
|
||||
D: ndarray::Dimension + 'a,
|
||||
I: IntoIterator<Item = (K, ndarray::ArrayView<'a, T, D>)>,
|
||||
>(
|
||||
iter: I,
|
||||
) -> Result<Self> {
|
||||
let tensors = iter
|
||||
.into_iter()
|
||||
.map(|(k, v)| Ok((k.as_ref().to_owned(), SafeArray::from_ndarray(v)?)))
|
||||
.collect::<Result<BTreeMap<String, SafeArray<'a>>>>()?;
|
||||
Ok(Self {
|
||||
tensors,
|
||||
metadata: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// impl<'a, K: AsRef<str>, T: IntoIterator<Item = (K, SafeArray<'a>)>> From<T> for SafeArrays<'a> {
|
||||
// fn from(iter: T) -> Self {
|
||||
// let tensors = iter
|
||||
// .into_iter()
|
||||
// .map(|(k, v)| (k.as_ref().to_owned(), v))
|
||||
// .collect();
|
||||
// Self {
|
||||
// tensors,
|
||||
// metadata: None,
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
impl<'a> SafeArrays<'a> {
|
||||
/// Create a new empty SafeArrays builder
|
||||
pub const fn new() -> Self {
|
||||
Self {
|
||||
tensors: BTreeMap::new(),
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Insert a SafeArray tensor with the given name
|
||||
pub fn insert_tensor<'b: 'a>(&mut self, name: impl AsRef<str>, tensor: SafeArray<'b>) {
|
||||
self.tensors.insert(name.as_ref().to_owned(), tensor);
|
||||
}
|
||||
|
||||
/// Insert an ndarray tensor with the given name
|
||||
///
|
||||
/// The array must be in standard layout and contiguous.
|
||||
pub fn insert_ndarray<'b: 'a, T: STDtype, D: ndarray::Dimension + 'a>(
|
||||
&mut self,
|
||||
name: impl AsRef<str>,
|
||||
array: ndarray::ArrayView<'b, T, D>,
|
||||
) -> Result<()> {
|
||||
self.insert_tensor(name, SafeArray::from_ndarray(array)?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Insert metadata key-value pair
|
||||
pub fn insert_metadata(&mut self, key: impl AsRef<str>, value: impl AsRef<str>) {
|
||||
self.metadata
|
||||
.get_or_insert_default()
|
||||
.insert(key.as_ref().to_owned(), value.as_ref().to_owned());
|
||||
}
|
||||
|
||||
/// Serialize all tensors and metadata to bytes
|
||||
pub fn serialize(self) -> Result<Vec<u8>> {
|
||||
let out = safetensors::serialize(self.tensors, self.metadata)
|
||||
.map_err(SafeTensorError::SafeTensor)?;
|
||||
Ok(out)
|
||||
}
|
||||
}
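// Editor's sketch (not part of the original diff): `insert_ndarray` refuses views that
// are not in standard (row-major, contiguous) layout instead of silently copying them.
#[test]
fn sketch_insert_ndarray_rejects_non_standard_layout() {
    let array = ndarray::Array2::<f32>::zeros((3, 4));
    let mut safe_arrays = SafeArrays::new();
    // A transposed view has reversed strides, so it is not standard layout.
    assert!(
        safe_arrays
            .insert_ndarray("t", array.view().reversed_axes())
            .is_err()
    );
}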
|
||||
|
||||
/// A tensor that can be serialized to SafeTensors format
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SafeArray<'a> {
|
||||
data: Cow<'a, [u8]>,
|
||||
shape: Vec<usize>,
|
||||
dtype: safetensors::tensor::Dtype,
|
||||
}
|
||||
|
||||
impl View for SafeArray<'_> {
|
||||
fn dtype(&self) -> safetensors::tensor::Dtype {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn shape(&self) -> &[usize] {
|
||||
&self.shape
|
||||
}
|
||||
|
||||
fn data(&self) -> Cow<'_, [u8]> {
|
||||
self.data.clone()
|
||||
}
|
||||
|
||||
fn data_len(&self) -> usize {
|
||||
self.data.len()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> SafeArray<'a> {
|
||||
fn from_ndarray<'b: 'a, T: STDtype, D: ndarray::Dimension + 'a>(
|
||||
array: ndarray::ArrayView<'b, T, D>,
|
||||
) -> Result<Self> {
|
||||
let shape = array.shape().to_vec();
|
||||
let dtype = T::dtype();
|
||||
if array.ndim() == 0 {
|
||||
return Err(SafeTensorError::InvalidTensorData(
|
||||
core::any::type_name::<T>(),
|
||||
"Cannot insert a scalar tensor".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if !array.is_standard_layout() {
|
||||
return Err(SafeTensorError::InvalidTensorData(
|
||||
core::any::type_name::<T>(),
|
||||
"ArrayView is not standard layout".to_string(),
|
||||
));
|
||||
}
|
||||
let data =
|
||||
bytemuck::cast_slice(array.to_slice().ok_or(SafeTensorError::InvalidTensorData(
|
||||
core::any::type_name::<T>(),
|
||||
"ArrayView is not contiguous".to_string(),
|
||||
))?);
|
||||
let safe_array = SafeArray {
|
||||
data: Cow::Borrowed(data),
|
||||
shape,
|
||||
dtype,
|
||||
};
|
||||
Ok(safe_array)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_safe_array_from_ndarray() {
|
||||
use ndarray::Array2;
|
||||
|
||||
let array = Array2::<f32>::zeros((3, 4));
|
||||
let safe_array = SafeArray::from_ndarray(array.view()).unwrap();
|
||||
assert_eq!(safe_array.shape, vec![3, 4]);
|
||||
assert_eq!(safe_array.dtype, safetensors::tensor::Dtype::F32);
|
||||
assert_eq!(safe_array.data.len(), 3 * 4 * 4); // 3 * 4 f32 elements, 4 bytes each
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_serialize_safe_arrays() {
|
||||
use ndarray::{Array2, Array3};
|
||||
|
||||
let mut safe_arrays = SafeArrays::new();
|
||||
let array = Array2::<f32>::zeros((3, 4));
|
||||
let array2 = Array3::<u16>::zeros((8, 1, 9));
|
||||
safe_arrays
|
||||
.insert_ndarray("test_tensor", array.view())
|
||||
.unwrap();
|
||||
safe_arrays
|
||||
.insert_ndarray("test_tensor2", array2.view())
|
||||
.unwrap();
|
||||
safe_arrays.insert_metadata("author", "example");
|
||||
|
||||
let serialized = safe_arrays.serialize().unwrap();
|
||||
assert!(!serialized.is_empty());
|
||||
|
||||
// Deserialize to check if it works
|
||||
let deserialized = SafeArraysView::from_bytes(&serialized).unwrap();
|
||||
assert_eq!(deserialized.len(), 2);
|
||||
assert_eq!(
|
||||
deserialized
|
||||
.tensor::<f32, ndarray::Ix2>("test_tensor")
|
||||
.unwrap()
|
||||
.shape(),
|
||||
&[3, 4]
|
||||
);
|
||||
assert_eq!(
|
||||
deserialized
|
||||
.tensor::<u16, ndarray::Ix3>("test_tensor2")
|
||||
.unwrap()
|
||||
.shape(),
|
||||
&[8, 1, 9]
|
||||
);
|
||||
}
|
||||
36
ndcv-bridge/Cargo.toml
Normal file
@@ -0,0 +1,36 @@
|
||||
[package]
|
||||
name = "ndcv-bridge"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
bounding-box.workspace = true
|
||||
nalgebra.workspace = true
|
||||
bytemuck.workspace = true
|
||||
error-stack.workspace = true
|
||||
fast_image_resize.workspace = true
|
||||
ndarray = { workspace = true, features = ["rayon"] }
|
||||
num.workspace = true
|
||||
opencv = { workspace = true, optional = true }
|
||||
rayon = "1.10.0"
|
||||
thiserror.workspace = true
|
||||
tracing = "0.1.41"
|
||||
wide = "0.7.32"
|
||||
img-parts.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
divan.workspace = true
|
||||
ndarray-npy.workspace = true
|
||||
|
||||
[features]
|
||||
opencv = ["dep:opencv"]
|
||||
default = ["opencv"]
|
||||
|
||||
|
||||
[[bench]]
|
||||
name = "conversions"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "gaussian"
|
||||
harness = false
|
||||
75
ndcv-bridge/benches/conversions.rs
Normal file
@@ -0,0 +1,75 @@
|
||||
use divan::black_box;
|
||||
use ndcv_bridge::*;
|
||||
|
||||
// #[global_allocator]
|
||||
// static ALLOC: AllocProfiler = AllocProfiler::system();
|
||||
|
||||
fn main() {
|
||||
divan::main();
|
||||
}
|
||||
|
||||
#[divan::bench]
|
||||
fn bench_3d_mat_to_ndarray_512() {
|
||||
bench_mat_to_3d_ndarray(512);
|
||||
}
|
||||
|
||||
#[divan::bench]
|
||||
fn bench_3d_mat_to_ndarray_1024() {
|
||||
bench_mat_to_3d_ndarray(1024);
|
||||
}
|
||||
|
||||
#[divan::bench]
|
||||
fn bench_3d_mat_to_ndarray_2k() {
|
||||
bench_mat_to_3d_ndarray(2048);
|
||||
}
|
||||
|
||||
#[divan::bench]
|
||||
fn bench_3d_mat_to_ndarray_4k() {
|
||||
bench_mat_to_3d_ndarray(4096);
|
||||
}
|
||||
|
||||
#[divan::bench]
|
||||
fn bench_3d_mat_to_ndarray_8k() {
|
||||
bench_mat_to_3d_ndarray(8192);
|
||||
}
|
||||
|
||||
#[divan::bench]
|
||||
fn bench_3d_mat_to_ndarray_8k_ref() {
|
||||
bench_mat_to_3d_ndarray_ref(8192);
|
||||
}
|
||||
|
||||
#[divan::bench]
|
||||
fn bench_2d_mat_to_ndarray_8k_ref() {
|
||||
bench_mat_to_2d_ndarray(8192);
|
||||
}
|
||||
|
||||
fn bench_mat_to_2d_ndarray(size: i32) -> ndarray::Array2<u8> {
|
||||
let mat =
|
||||
opencv::core::Mat::new_nd_with_default(&[size, size], opencv::core::CV_8UC1, (200).into())
|
||||
.expect("failed");
|
||||
let ndarray: ndarray::Array2<u8> = mat.as_ndarray().expect("failed").to_owned();
|
||||
ndarray
|
||||
}
|
||||
|
||||
fn bench_mat_to_3d_ndarray(size: i32) -> ndarray::Array3<u8> {
|
||||
let mat = opencv::core::Mat::new_nd_with_default(
|
||||
&[size, size],
|
||||
opencv::core::CV_8UC3,
|
||||
(200, 100, 10).into(),
|
||||
)
|
||||
.expect("failed");
|
||||
// ndarray::Array3::<u8>::from_mat(black_box(mat)).expect("failed")
|
||||
let ndarray: ndarray::Array3<u8> = mat.as_ndarray().expect("failed").to_owned();
|
||||
ndarray
|
||||
}
|
||||
|
||||
fn bench_mat_to_3d_ndarray_ref(size: i32) {
|
||||
let mut mat = opencv::core::Mat::new_nd_with_default(
|
||||
&[size, size],
|
||||
opencv::core::CV_8UC3,
|
||||
(200, 100, 10).into(),
|
||||
)
|
||||
.expect("failed");
|
||||
let array: ndarray::ArrayView3<u8> = black_box(&mut mat).as_ndarray().expect("failed");
|
||||
let _ = black_box(array);
|
||||
}
|
||||
265
ndcv-bridge/benches/gaussian.rs
Normal file
@@ -0,0 +1,265 @@
|
||||
use divan::black_box;
|
||||
use ndarray::*;
|
||||
use ndcv_bridge::*;
|
||||
|
||||
// #[global_allocator]
|
||||
// static ALLOC: AllocProfiler = AllocProfiler::system();
|
||||
|
||||
fn main() {
|
||||
divan::main();
|
||||
}
|
||||
|
||||
// Helper function to create test images with different patterns
|
||||
fn create_test_image(size: usize, pattern: &str) -> Array3<u8> {
|
||||
let mut arr = Array3::<u8>::zeros((size, size, 3));
|
||||
match pattern {
|
||||
"edges" => {
|
||||
// Create a pattern with sharp edges
|
||||
arr.slice_mut(s![size / 4..3 * size / 4, size / 4..3 * size / 4, ..])
|
||||
.fill(255);
|
||||
}
|
||||
"gradient" => {
|
||||
// Create a gradual gradient
|
||||
for i in 0..size {
|
||||
let val = (i * 255 / size) as u8;
|
||||
arr.slice_mut(s![i, .., ..]).fill(val);
|
||||
}
|
||||
}
|
||||
"checkerboard" => {
|
||||
// Create a checkerboard pattern
|
||||
for i in 0..size {
|
||||
for j in 0..size {
|
||||
if (i / 20 + j / 20) % 2 == 0 {
|
||||
arr[[i, j, 0]] = 255;
|
||||
arr[[i, j, 1]] = 255;
|
||||
arr[[i, j, 2]] = 255;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => arr.fill(255), // Default to solid white
|
||||
}
|
||||
arr
|
||||
}
|
||||
|
||||
#[divan::bench_group]
|
||||
mod sizes {
|
||||
use super::*;
|
||||
// Benchmark different image sizes
|
||||
#[divan::bench(args = [512, 1024, 2048, 4096])]
|
||||
fn bench_gaussian_sizes_u8(size: usize) {
|
||||
let arr = Array3::<u8>::ones((size, size, 3));
|
||||
let _out = black_box(
|
||||
arr.gaussian_blur((3, 3), 1.0, 1.0, BorderType::BorderConstant)
|
||||
.unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
#[divan::bench(args = [512, 1024, 2048, 4096])]
|
||||
fn bench_gaussian_sizes_u8_inplace(size: usize) {
|
||||
let mut arr = Array3::<u8>::ones((size, size, 3));
|
||||
black_box(
|
||||
arr.gaussian_blur_inplace((3, 3), 1.0, 1.0, BorderType::BorderConstant)
|
||||
.unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
#[divan::bench(args = [512, 1024, 2048, 4096])]
|
||||
fn bench_gaussian_sizes_f32(size: usize) {
|
||||
let arr = Array3::<f32>::ones((size, size, 3));
|
||||
let _out = black_box(
|
||||
arr.gaussian_blur((3, 3), 1.0, 1.0, BorderType::BorderConstant)
|
||||
.unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
#[divan::bench(args = [512, 1024, 2048, 4096])]
|
||||
fn bench_gaussian_sizes_f32_inplace(size: usize) {
|
||||
let mut arr = Array3::<f32>::ones((size, size, 3));
|
||||
black_box(
|
||||
arr.gaussian_blur_inplace((3, 3), 1.0, 1.0, BorderType::BorderConstant)
|
||||
.unwrap(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark different kernel sizes
|
||||
#[divan::bench(args = [(3, 3), (5, 5), (7, 7), (9, 9), (11, 11)])]
|
||||
fn bench_gaussian_kernels(kernel_size: (u8, u8)) {
|
||||
let mut arr = Array3::<u8>::ones((1000, 1000, 3));
|
||||
arr.gaussian_blur_inplace(kernel_size, 1.0, 1.0, BorderType::BorderConstant)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Benchmark different sigma values
|
||||
#[divan::bench(args = [0.5, 1.0, 2.0, 5.0])]
|
||||
fn bench_gaussian_sigmas(sigma: f64) {
|
||||
let mut arr = Array3::<u8>::ones((1000, 1000, 3));
|
||||
arr.gaussian_blur_inplace((3, 3), sigma, sigma, BorderType::BorderConstant)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Benchmark different sigma_x and sigma_y combinations
|
||||
#[divan::bench(args = [(0.5, 2.0), (1.0, 1.0), (2.0, 0.5), (3.0, 1.0)])]
|
||||
fn bench_gaussian_asymmetric_sigmas(sigmas: (f64, f64)) {
|
||||
let mut arr = Array3::<u8>::ones((1000, 1000, 3));
|
||||
arr.gaussian_blur_inplace((3, 3), sigmas.0, sigmas.1, BorderType::BorderConstant)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Benchmark different border types
|
||||
#[divan::bench]
|
||||
fn bench_gaussian_border_types() -> Vec<()> {
|
||||
let border_types = [
|
||||
BorderType::BorderConstant,
|
||||
BorderType::BorderReplicate,
|
||||
BorderType::BorderReflect,
|
||||
BorderType::BorderReflect101,
|
||||
];
|
||||
|
||||
let mut arr = Array3::<u8>::ones((1000, 1000, 3));
|
||||
border_types
|
||||
.iter()
|
||||
.map(|border_type| {
|
||||
arr.gaussian_blur_inplace((3, 3), 1.0, 1.0, *border_type)
|
||||
.unwrap();
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
// Benchmark different image patterns
|
||||
#[divan::bench]
|
||||
fn bench_gaussian_patterns() {
|
||||
let patterns = ["edges", "gradient", "checkerboard", "solid"];
|
||||
|
||||
patterns.iter().for_each(|&pattern| {
|
||||
let mut arr = create_test_image(1000, pattern);
|
||||
arr.gaussian_blur_inplace((3, 3), 1.0, 1.0, BorderType::BorderConstant)
|
||||
.unwrap();
|
||||
})
|
||||
}
|
||||
|
||||
#[divan::bench_group]
|
||||
mod allocation {
|
||||
use super::*;
|
||||
#[divan::bench]
|
||||
fn bench_gaussian_allocation_inplace() {
|
||||
let mut arr = Array3::<f32>::ones((3840, 2160, 3));
|
||||
|
||||
black_box(
|
||||
arr.gaussian_blur_inplace((3, 3), 1.0, 1.0, BorderType::BorderConstant)
|
||||
.unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
#[divan::bench]
|
||||
fn bench_gaussian_allocation_allocate() {
|
||||
let arr = Array3::<f32>::ones((3840, 2160, 3));
|
||||
|
||||
let _out = black_box(
|
||||
arr.gaussian_blur((3, 3), 1.0, 1.0, BorderType::BorderConstant)
|
||||
.unwrap(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[divan::bench_group]
|
||||
mod realistic {
|
||||
use super::*;
|
||||
#[divan::bench]
|
||||
fn small_800_600_3x3() {
|
||||
let small_blur = Array3::<u8>::ones((800, 600, 3));
|
||||
let _blurred = black_box(
|
||||
small_blur
|
||||
.gaussian_blur((3, 3), 0.5, 0.5, BorderType::BorderConstant)
|
||||
.unwrap(),
|
||||
);
|
||||
}
|
||||
#[divan::bench]
|
||||
fn small_800_600_3x3_inplace() {
|
||||
let mut small_blur = Array3::<u8>::ones((800, 600, 3));
|
||||
small_blur
|
||||
.gaussian_blur_inplace((3, 3), 0.5, 0.5, BorderType::BorderConstant)
|
||||
.unwrap();
|
||||
}
|
||||
#[divan::bench]
|
||||
fn medium_1920x1080_5x5() {
|
||||
let mut medium_blur = Array3::<u8>::ones((1920, 1080, 3));
|
||||
let _blurred = black_box(
|
||||
medium_blur
|
||||
.gaussian_blur_inplace((5, 5), 2.0, 2.0, BorderType::BorderConstant)
|
||||
.unwrap(),
|
||||
);
|
||||
}
|
||||
#[divan::bench]
|
||||
fn medium_1920x1080_5x5_inplace() {
|
||||
let mut medium_blur = Array3::<u8>::ones((1920, 1080, 3));
|
||||
medium_blur
|
||||
.gaussian_blur_inplace((5, 5), 2.0, 2.0, BorderType::BorderConstant)
|
||||
.unwrap();
|
||||
}
|
||||
#[divan::bench]
|
||||
fn large_3840x2160_9x9() {
|
||||
let large_blur = Array3::<u8>::ones((3840, 2160, 3));
|
||||
let _blurred = black_box(
|
||||
large_blur
|
||||
.gaussian_blur((9, 9), 5.0, 5.0, BorderType::BorderConstant)
|
||||
.unwrap(),
|
||||
);
|
||||
}
|
||||
#[divan::bench]
|
||||
fn large_3840x2160_9x9_inplace() {
|
||||
let mut large_blur = Array3::<u8>::ones((3840, 2160, 3));
|
||||
large_blur
|
||||
.gaussian_blur_inplace((9, 9), 5.0, 5.0, BorderType::BorderConstant)
|
||||
.unwrap();
|
||||
}
|
||||
#[divan::bench]
|
||||
fn small_800_600_3x3_f32() {
|
||||
let small_blur = Array3::<f32>::ones((800, 600, 3));
|
||||
let _blurred = black_box(
|
||||
small_blur
|
||||
.gaussian_blur((3, 3), 0.5, 0.5, BorderType::BorderConstant)
|
||||
.unwrap(),
|
||||
);
|
||||
}
|
||||
#[divan::bench]
|
||||
fn small_800_600_3x3_inplace_f32() {
|
||||
let mut small_blur = Array3::<f32>::ones((800, 600, 3));
|
||||
small_blur
|
||||
.gaussian_blur_inplace((3, 3), 0.5, 0.5, BorderType::BorderConstant)
|
||||
.unwrap();
|
||||
}
|
||||
#[divan::bench]
|
||||
fn medium_1920x1080_5x5_f32() {
|
||||
let mut medium_blur = Array3::<f32>::ones((1920, 1080, 3));
|
||||
let _blurred = black_box(
|
||||
medium_blur
|
||||
.gaussian_blur_inplace((5, 5), 2.0, 2.0, BorderType::BorderConstant)
|
||||
.unwrap(),
|
||||
);
|
||||
}
|
||||
#[divan::bench]
|
||||
fn medium_1920x1080_5x5_inplace_f32() {
|
||||
let mut medium_blur = Array3::<f32>::ones((1920, 1080, 3));
|
||||
medium_blur
|
||||
.gaussian_blur_inplace((5, 5), 2.0, 2.0, BorderType::BorderConstant)
|
||||
.unwrap();
|
||||
}
|
||||
#[divan::bench]
|
||||
fn large_3840x2160_9x9_f32() {
|
||||
let large_blur = Array3::<f32>::ones((3840, 2160, 3));
|
||||
let _blurred = black_box(
|
||||
large_blur
|
||||
.gaussian_blur((9, 9), 5.0, 5.0, BorderType::BorderConstant)
|
||||
.unwrap(),
|
||||
);
|
||||
}
|
||||
#[divan::bench]
|
||||
fn large_3840x2160_9x9_inplace_f32() {
|
||||
let mut large_blur = Array3::<f32>::ones((3840, 2160, 3));
|
||||
large_blur
|
||||
.gaussian_blur_inplace((9, 9), 5.0, 5.0, BorderType::BorderConstant)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
180
ndcv-bridge/src/blend.rs
Normal file
@@ -0,0 +1,180 @@
|
||||
use crate::prelude_::*;
|
||||
use ndarray::*;
|
||||
|
||||
type Result<T, E = Report<NdCvError>> = std::result::Result<T, E>;
|
||||
|
||||
mod seal {
|
||||
pub trait Sealed {}
|
||||
impl<T: ndarray::Data<Elem = f32>> Sealed for ndarray::ArrayBase<T, ndarray::Ix3> {}
|
||||
}
|
||||
pub trait NdBlend<T, D: ndarray::Dimension>: seal::Sealed {
|
||||
fn blend(
|
||||
&self,
|
||||
mask: ndarray::ArrayView<T, D::Smaller>,
|
||||
other: ndarray::ArrayView<T, D>,
|
||||
alpha: T,
|
||||
) -> Result<ndarray::Array<T, D>>;
|
||||
fn blend_inplace(
|
||||
&mut self,
|
||||
mask: ndarray::ArrayView<T, D::Smaller>,
|
||||
other: ndarray::ArrayView<T, D>,
|
||||
alpha: T,
|
||||
) -> Result<()>;
|
||||
}
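// Editor's note (not part of the original diff): per pixel and channel the blend is
//   out = src * (1.0 - mask * alpha) + other * (mask * alpha)
// so mask = 0 keeps the source pixel, mask = 1 with alpha = 1 takes `other`, and
// values in between interpolate linearly.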
|
||||
|
||||
impl<S> NdBlend<f32, Ix3> for ndarray::ArrayBase<S, Ix3>
|
||||
where
|
||||
S: ndarray::DataMut<Elem = f32>,
|
||||
{
|
||||
fn blend(
|
||||
&self,
|
||||
mask: ndarray::ArrayView<f32, Ix2>,
|
||||
other: ndarray::ArrayView<f32, Ix3>,
|
||||
alpha: f32,
|
||||
) -> Result<ndarray::Array<f32, Ix3>> {
|
||||
if self.shape() != other.shape() {
|
||||
return Err(NdCvError)
|
||||
.attach_printable("Shapes of image and other imagge do not match");
|
||||
}
|
||||
if self.shape()[0] != mask.shape()[0] || self.shape()[1] != mask.shape()[1] {
|
||||
return Err(NdCvError).attach_printable("Shapes of image and mask do not match");
|
||||
}
|
||||
|
||||
let mut output = ndarray::Array3::zeros(self.dim());
|
||||
let (_height, _width, channels) = self.dim();
|
||||
|
||||
Zip::from(output.lanes_mut(Axis(2)))
|
||||
.and(self.lanes(Axis(2)))
|
||||
.and(other.lanes(Axis(2)))
|
||||
.and(mask)
|
||||
.par_for_each(|mut out, this, other, mask| {
|
||||
let this = wide::f32x4::from(this.as_slice().expect("Invalid self array"));
|
||||
let other = wide::f32x4::from(other.as_slice().expect("Invalid other array"));
|
||||
let mask = wide::f32x4::splat(mask * alpha);
|
||||
let o = this * (1.0 - mask) + other * mask;
|
||||
out.as_slice_mut()
|
||||
.expect("Failed to get mutable slice")
|
||||
.copy_from_slice(&o.as_array_ref()[..channels]);
|
||||
});
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
fn blend_inplace(
|
||||
&mut self,
|
||||
mask: ndarray::ArrayView<f32, <Ix3 as Dimension>::Smaller>,
|
||||
other: ndarray::ArrayView<f32, Ix3>,
|
||||
alpha: f32,
|
||||
) -> Result<()> {
|
||||
if self.shape() != other.shape() {
|
||||
return Err(NdCvError)
|
||||
.attach_printable("Shapes of image and other imagge do not match");
|
||||
}
|
||||
if self.shape()[0] != mask.shape()[0] || self.shape()[1] != mask.shape()[1] {
|
||||
return Err(NdCvError).attach_printable("Shapes of image and mask do not match");
|
||||
}
|
||||
|
||||
let (_height, _width, channels) = self.dim();
|
||||
|
||||
// Zip::from(self.lanes_mut(Axis(2)))
|
||||
// .and(other.lanes(Axis(2)))
|
||||
// .and(mask)
|
||||
// .par_for_each(|mut this, other, mask| {
|
||||
// let this_wide = wide::f32x4::from(this.as_slice().expect("Invalid self array"));
|
||||
// let other = wide::f32x4::from(other.as_slice().expect("Invalid other array"));
|
||||
// let mask = wide::f32x4::splat(mask * alpha);
|
||||
// let o = this_wide * (1.0 - mask) + other * mask;
|
||||
// this.as_slice_mut()
|
||||
// .expect("Failed to get mutable slice")
|
||||
// .copy_from_slice(&o.as_array_ref()[..channels]);
|
||||
// });
|
||||
let this = self
|
||||
.as_slice_mut()
|
||||
.ok_or(NdCvError)
|
||||
.attach_printable("Failed to get source image as a continuous slice")?;
|
||||
let other = other
|
||||
.as_slice()
|
||||
.ok_or(NdCvError)
|
||||
.attach_printable("Failed to get other image as a continuous slice")?;
|
||||
let mask = mask
|
||||
.as_slice()
|
||||
.ok_or(NdCvError)
|
||||
.attach_printable("Failed to get mask as a continuous slice")?;
|
||||
|
||||
use rayon::prelude::*;
|
||||
this.par_chunks_exact_mut(channels)
|
||||
.zip(other.par_chunks_exact(channels))
|
||||
.zip(mask)
|
||||
.for_each(|((this, other), mask)| {
|
||||
let this_wide = wide::f32x4::from(&*this);
|
||||
let other = wide::f32x4::from(other);
|
||||
let mask = wide::f32x4::splat(mask * alpha);
|
||||
this.copy_from_slice(
|
||||
&(this_wide * (1.0 - mask) + other * mask).as_array_ref()[..channels],
|
||||
);
|
||||
});
|
||||
|
||||
// for h in 0.._height {
|
||||
// for w in 0.._width {
|
||||
// let mask_index = h * _width + w;
|
||||
// let mask = mask[mask_index];
|
||||
// let mask = wide::f32x4::splat(mask * alpha);
|
||||
// let this = &mut this[mask_index * channels..(mask_index + 1) * channels];
|
||||
// let other = &other[mask_index * channels..(mask_index + 1) * channels];
|
||||
// let this_wide = wide::f32x4::from(&*this);
|
||||
// let other = wide::f32x4::from(other);
|
||||
// let o = this_wide * (1.0 - mask) + other * mask;
|
||||
// this.copy_from_slice(&o.as_array_ref()[..channels]);
|
||||
// }
|
||||
// }
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_blend() {
|
||||
let img = Array3::<f32>::from_shape_fn((10, 10, 3), |(i, j, k)| match (i, j, k) {
|
||||
(0..=3, _, 0) => 1f32, // red
|
||||
(4..=6, _, 1) => 1f32, // green
|
||||
(7..=9, _, 2) => 1f32, // blue
|
||||
_ => 0f32,
|
||||
});
|
||||
let other = img.clone().permuted_axes([1, 0, 2]).to_owned();
|
||||
let mask = Array2::<f32>::from_shape_fn((10, 10), |(_, j)| if j > 5 { 1f32 } else { 0f32 });
|
||||
// let other = Array3::<f32>::zeros((10, 10, 3));
|
||||
let out = img.blend(mask.view(), other.view(), 1f32).unwrap();
|
||||
let out_u8 = out.mapv(|v| (v * 255f32) as u8);
|
||||
let expected = Array3::<u8>::from_shape_fn((10, 10, 3), |(i, j, k)| {
|
||||
match (i, j, k) {
|
||||
(0..=3, 0..=5, 0) => u8::MAX, // red
|
||||
(4..=6, 0..=5, 1) | (_, 6, 1) => u8::MAX, // green
|
||||
(7..=9, 0..=5, 2) | (_, 7..=10, 2) => u8::MAX, // blue
|
||||
_ => u8::MIN,
|
||||
}
|
||||
});
|
||||
assert_eq!(out_u8, expected);
|
||||
}
|
||||
|
||||
// #[test]
|
||||
// pub fn test_blend_inplace() {
|
||||
// let mut img = Array3::<f32>::from_shape_fn((10, 10, 3), |(i, j, k)| match (i, j, k) {
|
||||
// (0..=3, _, 0) => 1f32, // red
|
||||
// (4..=6, _, 1) => 1f32, // green
|
||||
// (7..=9, _, 2) => 1f32, // blue
|
||||
// _ => 0f32,
|
||||
// });
|
||||
// let other = img.clone().permuted_axes([1, 0, 2]);
|
||||
// let mask = Array2::<f32>::from_shape_fn((10, 10), |(_, j)| if j > 5 { 1f32 } else { 0f32 });
|
||||
// // let other = Array3::<f32>::zeros((10, 10, 3));
|
||||
// img.blend_inplace(mask.view(), other.view(), 1f32).unwrap();
|
||||
// let out_u8 = img.mapv(|v| (v * 255f32) as u8);
|
||||
// let expected = Array3::<u8>::from_shape_fn((10, 10, 3), |(i, j, k)| {
|
||||
// match (i, j, k) {
|
||||
// (0..=3, 0..=5, 0) => u8::MAX, // red
|
||||
// (4..=6, 0..=5, 1) | (_, 6, 1) => u8::MAX, // green
|
||||
// (7..=9, 0..=5, 2) | (_, 7..=10, 2) => u8::MAX, // blue
|
||||
// _ => u8::MIN,
|
||||
// }
|
||||
// });
|
||||
// assert_eq!(out_u8, expected);
|
||||
// }
|
||||
48
ndcv-bridge/src/bounding_rect.rs
Normal file
@@ -0,0 +1,48 @@
|
||||
//! Calculates the up-right bounding rectangle of a point set or non-zero pixels of gray-scale image.
|
||||
//! The function calculates and returns the minimal up-right bounding rectangle for the specified point set or non-zero pixels of gray-scale image.
|
||||
use crate::{NdAsImage, prelude_::*};
|
||||
pub trait BoundingRect: seal::SealedInternal {
|
||||
fn bounding_rect(&self) -> Result<bounding_box::Aabb2<i32>, NdCvError>;
|
||||
}
|
||||
|
||||
mod seal {
|
||||
pub trait SealedInternal {}
|
||||
impl<T, S: ndarray::Data<Elem = T>> SealedInternal for ndarray::ArrayBase<S, ndarray::Ix2> {}
|
||||
}
|
||||
|
||||
impl<S: ndarray::Data<Elem = u8>> BoundingRect for ndarray::ArrayBase<S, ndarray::Ix2> {
|
||||
fn bounding_rect(&self) -> Result<bounding_box::Aabb2<i32>, NdCvError> {
|
||||
let mat = self.as_image_mat()?;
|
||||
let rect = opencv::imgproc::bounding_rect(mat.as_ref()).change_context(NdCvError)?;
|
||||
Ok(bounding_box::Aabb2::from_xywh(
|
||||
rect.x,
|
||||
rect.y,
|
||||
rect.width,
|
||||
rect.height,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bounding_rect_empty() {
|
||||
let arr = ndarray::Array2::<u8>::zeros((10, 10));
|
||||
let rect = arr.bounding_rect().unwrap();
|
||||
assert_eq!(rect, bounding_box::Aabb2::from_xywh(0, 0, 0, 0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bounding_rect_valued() {
|
||||
let mut arr = ndarray::Array2::<u8>::zeros((10, 10));
|
||||
crate::NdRoiMut::roi_mut(&mut arr, bounding_box::Aabb2::from_xywh(1, 1, 3, 3)).fill(1);
|
||||
let rect = arr.bounding_rect().unwrap();
|
||||
assert_eq!(rect, bounding_box::Aabb2::from_xywh(1, 1, 3, 3));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bounding_rect_complex() {
|
||||
let mut arr = ndarray::Array2::<u8>::zeros((10, 10));
|
||||
crate::NdRoiMut::roi_mut(&mut arr, bounding_box::Aabb2::from_xywh(1, 3, 3, 3)).fill(1);
|
||||
crate::NdRoiMut::roi_mut(&mut arr, bounding_box::Aabb2::from_xywh(2, 3, 3, 5)).fill(5);
|
||||
let rect = arr.bounding_rect().unwrap();
|
||||
assert_eq!(rect, bounding_box::Aabb2::from_xywh(1, 3, 4, 5));
|
||||
}
|
||||
4
ndcv-bridge/src/codec.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
pub mod codecs;
|
||||
pub mod decode;
|
||||
pub mod encode;
|
||||
pub mod error;
|
||||
218
ndcv-bridge/src/codec/codecs.rs
Normal file
@@ -0,0 +1,218 @@
|
||||
use super::decode::Decoder;
|
||||
use super::encode::Encoder;
|
||||
use crate::NdCvError;
|
||||
use crate::conversions::matref::MatRef;
|
||||
use error_stack::*;
|
||||
use img_parts::{
|
||||
Bytes,
|
||||
jpeg::{Jpeg, markers},
|
||||
};
|
||||
use opencv::{
|
||||
core::{Mat, Vector, VectorToVec},
|
||||
imgcodecs::{ImreadModes, ImwriteFlags, imdecode, imencode},
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum CvEncoder {
|
||||
Jpeg(CvJpegEncFlags),
|
||||
Tiff(CvTiffEncFlags),
|
||||
}
|
||||
|
||||
pub enum EncKind {
|
||||
Jpeg,
|
||||
Tiff,
|
||||
}
|
||||
|
||||
impl CvEncoder {
|
||||
fn kind(&self) -> EncKind {
|
||||
match self {
|
||||
Self::Jpeg(_) => EncKind::Jpeg,
|
||||
Self::Tiff(_) => EncKind::Tiff,
|
||||
}
|
||||
}
|
||||
|
||||
fn extension(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Jpeg(_) => ".jpg",
|
||||
Self::Tiff(_) => ".tiff",
|
||||
}
|
||||
}
|
||||
|
||||
fn to_cv_param_list(&self) -> Vector<i32> {
|
||||
match self {
|
||||
Self::Jpeg(flags) => flags.to_cv_param_list(),
|
||||
Self::Tiff(flags) => flags.to_cv_param_list(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct CvJpegEncFlags {
|
||||
quality: Option<usize>,
|
||||
progressive: Option<bool>,
|
||||
optimize: Option<bool>,
|
||||
remove_app0: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct CvTiffEncFlags {
|
||||
compression: Option<i32>,
|
||||
}
|
||||
|
||||
impl CvTiffEncFlags {
|
||||
pub fn new() -> Self {
|
||||
Self::default().with_compression(1)
|
||||
}
|
||||
|
||||
pub fn with_compression(mut self, compression: i32) -> Self {
|
||||
self.compression = Some(compression);
|
||||
self
|
||||
}
|
||||
|
||||
fn to_cv_param_list(&self) -> Vector<i32> {
|
||||
let iter = [(
|
||||
ImwriteFlags::IMWRITE_TIFF_COMPRESSION as i32,
|
||||
self.compression.map(|i| i as i32),
|
||||
)]
|
||||
.into_iter()
|
||||
.filter_map(|(flag, opt)| opt.map(|o| [flag, o]))
|
||||
.flatten();
|
||||
|
||||
Vector::from_iter(iter)
|
||||
}
|
||||
}
|
||||
|
||||
impl CvJpegEncFlags {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn with_quality(mut self, quality: usize) -> Self {
|
||||
self.quality = Some(quality);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn remove_app0_marker(mut self, val: bool) -> Self {
|
||||
self.remove_app0 = Some(val);
|
||||
self
|
||||
}
|
||||
|
||||
fn to_cv_param_list(&self) -> Vector<i32> {
|
||||
let iter = [
|
||||
(
|
||||
ImwriteFlags::IMWRITE_JPEG_QUALITY as i32,
|
||||
self.quality.map(|i| i as i32),
|
||||
),
|
||||
(
|
||||
ImwriteFlags::IMWRITE_JPEG_PROGRESSIVE as i32,
|
||||
self.progressive.map(|i| i as i32),
|
||||
),
|
||||
(
|
||||
ImwriteFlags::IMWRITE_JPEG_OPTIMIZE as i32,
|
||||
self.optimize.map(|i| i as i32),
|
||||
),
|
||||
]
|
||||
.into_iter()
|
||||
.filter_map(|(flag, opt)| opt.map(|o| [flag, o]))
|
||||
.flatten();
|
||||
|
||||
Vector::from_iter(iter)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder for CvEncoder {
|
||||
type Input<'a>
|
||||
= MatRef<'a>
|
||||
where
|
||||
Self: 'a;
|
||||
|
||||
fn encode(&self, input: Self::Input<'_>) -> Result<Vec<u8>, NdCvError> {
|
||||
let mut buf = Vector::default();
|
||||
|
||||
let params = self.to_cv_param_list();
|
||||
|
||||
imencode(self.extension(), &input.as_ref(), &mut buf, ¶ms).change_context(NdCvError)?;
|
||||
|
||||
match self.kind() {
|
||||
EncKind::Jpeg => {
|
||||
let bytes = Bytes::from(buf.to_vec());
|
||||
let mut jpg = Jpeg::from_bytes(bytes).change_context(NdCvError)?;
|
||||
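// NOTE: APP0 (JFIF) segments are always stripped from JPEG output here; the
// `remove_app0` flag on `CvJpegEncFlags` is not consulted on this code path.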
jpg.remove_segments_by_marker(markers::APP0);
|
||||
let bytes = jpg.encoder().bytes();
|
||||
Ok(bytes.to_vec())
|
||||
}
|
||||
EncKind::Tiff => Ok(buf.to_vec()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum CvDecoder {
|
||||
Jpeg(CvJpegDecFlags),
|
||||
}
|
||||
|
||||
impl CvDecoder {
|
||||
fn to_cv_decode_flag(&self) -> i32 {
|
||||
match self {
|
||||
Self::Jpeg(flags) => flags.to_cv_decode_flag(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub enum ColorMode {
|
||||
#[default]
|
||||
Color,
|
||||
GrayScale,
|
||||
}
|
||||
|
||||
impl ColorMode {
|
||||
fn to_cv_decode_flag(&self) -> i32 {
|
||||
match self {
|
||||
Self::Color => ImreadModes::IMREAD_ANYCOLOR as i32,
|
||||
Self::GrayScale => ImreadModes::IMREAD_GRAYSCALE as i32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct CvJpegDecFlags {
|
||||
color_mode: ColorMode,
|
||||
ignore_orientation: bool,
|
||||
}
|
||||
|
||||
impl CvJpegDecFlags {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn with_color_mode(mut self, color_mode: ColorMode) -> Self {
|
||||
self.color_mode = color_mode;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_ignore_orientation(mut self, ignore_orientation: bool) -> Self {
|
||||
self.ignore_orientation = ignore_orientation;
|
||||
self
|
||||
}
|
||||
|
||||
fn to_cv_decode_flag(&self) -> i32 {
|
||||
let flag = self.color_mode.to_cv_decode_flag();
|
||||
|
||||
if self.ignore_orientation {
|
||||
flag | ImreadModes::IMREAD_IGNORE_ORIENTATION as i32
|
||||
} else {
|
||||
flag
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for CvDecoder {
|
||||
type Output = Mat;
|
||||
|
||||
fn decode(&self, input: impl AsRef<[u8]>) -> Result<Self::Output, NdCvError> {
|
||||
let flag = self.to_cv_decode_flag();
|
||||
let out = imdecode(&Vector::from_slice(input.as_ref()), flag).change_context(NdCvError)?;
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
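A minimal usage sketch of the flag builders and the `Decoder` impl above (not part of the diff; `jpeg_bytes` is an assumed input buffer and the paths assume the `codec` module is publicly exported):

```rust
use ndcv_bridge::codec::codecs::{ColorMode, CvDecoder, CvJpegDecFlags};
use ndcv_bridge::codec::decode::Decoder;

fn decode_to_mat(jpeg_bytes: &[u8]) -> error_stack::Result<opencv::core::Mat, ndcv_bridge::NdCvError> {
    // Build the decode flags with the builder methods defined above.
    let decoder = CvDecoder::Jpeg(
        CvJpegDecFlags::new()
            .with_color_mode(ColorMode::Color)
            .with_ignore_orientation(true),
    );
    // `Decoder::decode` wraps opencv's imdecode and returns a Mat.
    decoder.decode(jpeg_bytes)
}
```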
61
ndcv-bridge/src/codec/decode.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
#![deny(warnings)]
|
||||
|
||||
use super::codecs::CvDecoder;
|
||||
use super::error::ErrorReason;
|
||||
use crate::NdCvError;
|
||||
use crate::{NdAsImage, conversions::NdCvConversion};
|
||||
use error_stack::*;
|
||||
use ndarray::Array;
|
||||
use std::path::Path;
|
||||
|
||||
pub trait Decodable<D: Decoder>: Sized {
|
||||
fn decode(buf: impl AsRef<[u8]>, decoder: &D) -> Result<Self, NdCvError> {
|
||||
let output = decoder.decode(buf)?;
|
||||
Self::transform(output)
|
||||
}
|
||||
|
||||
fn read(&self, path: impl AsRef<Path>, decoder: &D) -> Result<Self, NdCvError> {
|
||||
let buf = std::fs::read(path)
|
||||
.map_err(|e| match e.kind() {
|
||||
std::io::ErrorKind::NotFound => {
|
||||
Report::new(e).attach_printable(ErrorReason::ImageReadFileNotFound)
|
||||
}
|
||||
std::io::ErrorKind::PermissionDenied => {
|
||||
Report::new(e).attach_printable(ErrorReason::ImageReadPermissionDenied)
|
||||
}
|
||||
std::io::ErrorKind::OutOfMemory => {
|
||||
Report::new(e).attach_printable(ErrorReason::OutOfMemory)
|
||||
}
|
||||
std::io::ErrorKind::StorageFull => {
|
||||
Report::new(e).attach_printable(ErrorReason::OutOfStorage)
|
||||
}
|
||||
_ => Report::new(e).attach_printable(ErrorReason::ImageReadOtherError),
|
||||
})
|
||||
.change_context(NdCvError)?;
|
||||
Self::decode(buf, decoder)
|
||||
}
|
||||
|
||||
fn transform(input: D::Output) -> Result<Self, NdCvError>;
|
||||
}
|
||||
|
||||
pub trait Decoder {
|
||||
type Output: Sized;
|
||||
fn decode(&self, buf: impl AsRef<[u8]>) -> Result<Self::Output, NdCvError>;
|
||||
}
|
||||
|
||||
impl<T: bytemuck::Pod + Copy, D: ndarray::Dimension> Decodable<CvDecoder> for Array<T, D>
|
||||
where
|
||||
Self: NdAsImage<T, D>,
|
||||
{
|
||||
fn transform(input: <CvDecoder as Decoder>::Output) -> Result<Self, NdCvError> {
|
||||
Self::from_mat(input)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_image() {
|
||||
use crate::codec::codecs::*;
|
||||
let img = std::fs::read("/Users/fs0c131y/Projects/face-detector/assets/selfie.jpg").unwrap();
|
||||
let decoder = CvDecoder::Jpeg(CvJpegDecFlags::new().with_ignore_orientation(true));
|
||||
let _out = ndarray::Array3::<u8>::decode(img, &decoder).unwrap();
|
||||
}
|
||||
56
ndcv-bridge/src/codec/encode.rs
Normal file
@@ -0,0 +1,56 @@
|
||||
use super::codecs::CvEncoder;
|
||||
use super::error::ErrorReason;
|
||||
use crate::conversions::NdAsImage;
|
||||
use crate::NdCvError;
|
||||
use error_stack::*;
|
||||
use ndarray::ArrayBase;
|
||||
use std::path::Path;
|
||||
|
||||
pub trait Encodable<E: Encoder> {
|
||||
fn encode(&self, encoder: &E) -> Result<Vec<u8>, NdCvError> {
|
||||
let input = self.transform()?;
|
||||
encoder.encode(input)
|
||||
}
|
||||
|
||||
fn write(&self, path: impl AsRef<Path>, encoder: &E) -> Result<(), NdCvError> {
|
||||
let buf = self.encode(encoder)?;
|
||||
|
||||
std::fs::write(path, buf)
|
||||
.map_err(|e| match e.kind() {
|
||||
std::io::ErrorKind::NotFound => {
|
||||
Report::new(e).attach_printable(ErrorReason::ImageWriteFileNotFound)
|
||||
}
|
||||
std::io::ErrorKind::PermissionDenied => {
|
||||
Report::new(e).attach_printable(ErrorReason::ImageWritePermissionDenied)
|
||||
}
|
||||
std::io::ErrorKind::OutOfMemory => {
|
||||
Report::new(e).attach_printable(ErrorReason::OutOfMemory)
|
||||
}
|
||||
std::io::ErrorKind::StorageFull => {
|
||||
Report::new(e).attach_printable(ErrorReason::OutOfStorage)
|
||||
}
|
||||
_ => Report::new(e).attach_printable(ErrorReason::ImageWriteOtherError),
|
||||
})
|
||||
.change_context(NdCvError)
|
||||
}
|
||||
|
||||
fn transform(&self) -> Result<<E as Encoder>::Input<'_>, NdCvError>;
|
||||
}
|
||||
|
||||
pub trait Encoder {
|
||||
type Input<'a>
|
||||
where
|
||||
Self: 'a;
|
||||
|
||||
fn encode(&self, input: Self::Input<'_>) -> Result<Vec<u8>, NdCvError>;
|
||||
}
|
||||
|
||||
impl<T: bytemuck::Pod + Copy, S: ndarray::Data<Elem = T>, D: ndarray::Dimension>
|
||||
Encodable<CvEncoder> for ArrayBase<S, D>
|
||||
where
|
||||
Self: NdAsImage<T, D>,
|
||||
{
|
||||
fn transform(&self) -> Result<<CvEncoder as Encoder>::Input<'_>, NdCvError> {
|
||||
self.as_image_mat()
|
||||
}
|
||||
}
|
||||
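Putting the `Decodable` and `Encodable` traits together, a hedged sketch of a decode/re-encode round trip on an owned ndarray (module paths and the `output.jpg` file name are assumptions, not part of the diff):

```rust
use ndcv_bridge::codec::codecs::{CvDecoder, CvEncoder, CvJpegDecFlags, CvJpegEncFlags};
use ndcv_bridge::codec::{decode::Decodable, encode::Encodable};

fn reencode(jpeg_bytes: &[u8]) -> error_stack::Result<(), ndcv_bridge::NdCvError> {
    let decoder = CvDecoder::Jpeg(CvJpegDecFlags::new().with_ignore_orientation(true));
    // `Decodable` is implemented for owned arrays; an HxWxC u8 array suits a colour JPEG.
    let img = ndarray::Array3::<u8>::decode(jpeg_bytes, &decoder)?;

    // Re-encode at quality 90 and write to disk via `Encodable::write`.
    let encoder = CvEncoder::Jpeg(CvJpegEncFlags::new().with_quality(90));
    img.write("output.jpg", &encoder)
}
```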
19
ndcv-bridge/src/codec/error.rs
Normal file
@@ -0,0 +1,19 @@
|
||||
#[derive(Debug)]
|
||||
pub enum ErrorReason {
|
||||
ImageReadFileNotFound,
|
||||
ImageReadPermissionDenied,
|
||||
ImageReadOtherError,
|
||||
|
||||
ImageWriteFileNotFound,
|
||||
ImageWritePermissionDenied,
|
||||
ImageWriteOtherError,
|
||||
|
||||
OutOfMemory,
|
||||
OutOfStorage,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ErrorReason {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{:?}", self)
|
||||
}
|
||||
}
|
||||
88
ndcv-bridge/src/color_space.rs
Normal file
@@ -0,0 +1,88 @@
|
||||
//! Colorspace conversion functions
|
||||
//! ## Example
|
||||
//! ```rust,ignore
//! use ndarray::Array3;
//! let arr = Array3::<u8>::ones((100, 100, 4));
//! let out: Array3<u8> = arr.cvt::<Rgba<u8>, Rgb<u8>>();
//! ```
|
||||
use crate::prelude_::*;
|
||||
use ndarray::*;
|
||||
|
||||
pub trait ColorSpace {
|
||||
type Elem: seal::Sealed;
|
||||
type Dim: ndarray::Dimension;
|
||||
const CHANNELS: usize;
|
||||
}
|
||||
|
||||
mod seal {
|
||||
pub trait Sealed: bytemuck::Pod {}
|
||||
// impl<T> Sealed for T {}
|
||||
impl Sealed for u8 {} // 0 to 255
|
||||
impl Sealed for u16 {} // 0 to 65535
|
||||
impl Sealed for f32 {} // 0 to 1
|
||||
}
|
||||
|
||||
macro_rules! define_color_space {
|
||||
($name:ident, $channels:expr, $depth:ty) => {
|
||||
pub struct $name<T> {
|
||||
__phantom: core::marker::PhantomData<T>,
|
||||
}
|
||||
impl<T: seal::Sealed> ColorSpace for $name<T> {
|
||||
type Elem = T;
|
||||
type Dim = $depth;
|
||||
const CHANNELS: usize = $channels;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
define_color_space!(Rgb, 3, Ix3);
|
||||
define_color_space!(Bgr, 3, Ix3);
|
||||
define_color_space!(Rgba, 4, Ix3);
|
||||
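The macro keeps adding colour spaces cheap; as a hypothetical illustration (not in the diff), a single-channel space pairing with 2-D arrays could be declared the same way:

```rust
// Hypothetical: 1 channel stored as a 2-D array (Ix2), per the sealed element types above.
define_color_space!(Gray, 1, Ix2);
```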
|
||||
pub trait NdArray<T, D: ndarray::Dimension> {}
|
||||
impl<T, D: ndarray::Dimension, S: ndarray::Data<Elem = T>> NdArray<S, D> for ArrayBase<S, D> {}
|
||||
|
||||
pub trait ConvertColor<T, U>
|
||||
where
|
||||
T: ColorSpace,
|
||||
U: ColorSpace,
|
||||
Self: NdArray<T::Elem, T::Dim>,
|
||||
{
|
||||
type Output: NdArray<U::Elem, U::Dim>;
|
||||
fn cvt(&self) -> Self::Output;
|
||||
}
|
||||
|
||||
// impl<T: seal::Sealed, S: ndarray::Data<Elem = T>> ConvertColor<Rgb<T>, Bgr<T>> for ArrayBase<S, Ix3>
|
||||
// where
|
||||
// Self: NdArray<T, Ix3>,
|
||||
// {
|
||||
// type Output = ArrayView3<'a, T>;
|
||||
// fn cvt(&self) -> CowArray<T, Ix3> {
|
||||
// self.view().permuted_axes([2, 1, 0]).into()
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// impl<T: seal::Sealed, S: ndarray::Data<Elem = T>> ConvertColor<Bgr<T>, Rgb<T>> for ArrayBase<S, Ix3>
|
||||
// where
|
||||
// Self: NdArray<T, Ix3>,
|
||||
// {
|
||||
// type Output = ArrayView3<'a, T>;
|
||||
// fn cvt(&self) -> CowArray<T, Ix3> {
|
||||
// self.view().permuted_axes([2, 1, 0]).into()
|
||||
// }
|
||||
// }
|
||||
|
||||
|
||||
// impl<T: seal::Sealed + num::One + num::Zero, S: ndarray::Data<Elem = T>>
|
||||
// ConvertColor<Rgb<T>, Rgba<T>> for ArrayBase<S, Ix3>
|
||||
// {
|
||||
// fn cvt(&self) -> CowArray<T, Ix3> {
|
||||
// let mut out = Array3::<T>::zeros((self.height(), self.width(), 4));
|
||||
// // Zip::from(&mut out).and(self).for_each(|out, &in_| {
|
||||
// // out[0] = in_[0];
|
||||
// // out[1] = in_[1];
|
||||
// // out[2] = in_[2];
|
||||
// // out[3] = T::one();
|
||||
// // });
|
||||
// out.into()
|
||||
// }
|
||||
// }
|
||||
113
ndcv-bridge/src/connected_components.rs
Normal file
@@ -0,0 +1,113 @@
|
||||
use crate::{NdAsImage, NdAsImageMut, conversions::MatAsNd, prelude_::*};
|
||||
|
||||
pub(crate) mod seal {
|
||||
pub trait ConnectedComponentOutput: Sized + Copy + bytemuck::Pod + num::Zero {
|
||||
fn as_cv_type() -> i32 {
|
||||
crate::type_depth::<Self>()
|
||||
}
|
||||
}
|
||||
impl ConnectedComponentOutput for i32 {}
|
||||
impl ConnectedComponentOutput for u16 {}
|
||||
}
|
||||
|
||||
pub trait NdCvConnectedComponents<T> {
|
||||
fn connected_components<O: seal::ConnectedComponentOutput>(
|
||||
&self,
|
||||
connectivity: Connectivity,
|
||||
) -> Result<ndarray::Array2<O>, NdCvError>;
|
||||
fn connected_components_with_stats<O: seal::ConnectedComponentOutput>(
|
||||
&self,
|
||||
connectivity: Connectivity,
|
||||
) -> Result<ConnectedComponentStats<O>, NdCvError>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub enum Connectivity {
|
||||
Four = 4,
|
||||
#[default]
|
||||
Eight = 8,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConnectedComponentStats<O: seal::ConnectedComponentOutput> {
|
||||
pub num_labels: i32,
|
||||
pub labels: ndarray::Array2<O>,
|
||||
pub stats: ndarray::Array2<i32>,
|
||||
pub centroids: ndarray::Array2<f64>,
|
||||
}
|
||||
|
||||
// use crate::conversions::NdCvConversionRef;
|
||||
impl<T: bytemuck::Pod, S: ndarray::Data<Elem = T>> NdCvConnectedComponents<T>
|
||||
for ndarray::ArrayBase<S, ndarray::Ix2>
|
||||
where
|
||||
ndarray::Array2<T>: NdAsImage<T, ndarray::Ix2>,
|
||||
{
|
||||
fn connected_components<O: seal::ConnectedComponentOutput>(
|
||||
&self,
|
||||
connectivity: Connectivity,
|
||||
) -> Result<ndarray::Array2<O>, NdCvError> {
|
||||
let mat = self.as_image_mat()?;
|
||||
let mut labels = ndarray::Array2::<O>::zeros(self.dim());
|
||||
let mut cv_labels = labels.as_image_mat_mut()?;
|
||||
opencv::imgproc::connected_components(
|
||||
mat.as_ref(),
|
||||
cv_labels.as_mut(),
|
||||
connectivity as i32,
|
||||
O::as_cv_type(),
|
||||
)
|
||||
.change_context(NdCvError)?;
|
||||
Ok(labels)
|
||||
}
|
||||
|
||||
fn connected_components_with_stats<O: seal::ConnectedComponentOutput>(
|
||||
&self,
|
||||
connectivity: Connectivity,
|
||||
) -> Result<ConnectedComponentStats<O>, NdCvError> {
|
||||
let mut labels = ndarray::Array2::<O>::zeros(self.dim());
|
||||
let mut stats = opencv::core::Mat::default();
|
||||
let mut centroids = opencv::core::Mat::default();
|
||||
let num_labels = opencv::imgproc::connected_components_with_stats(
|
||||
self.as_image_mat()?.as_ref(),
|
||||
labels.as_image_mat_mut()?.as_mut(),
|
||||
&mut stats,
|
||||
&mut centroids,
|
||||
connectivity as i32,
|
||||
O::as_cv_type(),
|
||||
)
|
||||
.change_context(NdCvError)?;
|
||||
let stats = stats.as_ndarray()?.to_owned();
|
||||
let centroids = centroids.as_ndarray()?.to_owned();
|
||||
Ok(ConnectedComponentStats {
|
||||
labels,
|
||||
stats,
|
||||
centroids,
|
||||
num_labels,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// #[test]
|
||||
// fn test_connected_components() {
|
||||
// use opencv::core::MatTrait as _;
|
||||
// let mat = opencv::core::Mat::new_nd_with_default(&[10, 10], opencv::core::CV_8UC1, 0.into())
|
||||
// .expect("failed");
|
||||
// let roi1 = opencv::core::Rect::new(2, 2, 2, 2);
|
||||
// let roi2 = opencv::core::Rect::new(6, 6, 3, 3);
|
||||
// let mut mat1 = opencv::core::Mat::roi(&mat, roi1).expect("failed");
|
||||
// mat1.set_scalar(1.into()).expect("failed");
|
||||
// let mut mat2 = opencv::core::Mat::roi(&mat, roi2).expect("failed");
|
||||
// mat2.set_scalar(1.into()).expect("failed");
|
||||
|
||||
// let array2: ndarray::ArrayView2<u8> = mat.as_ndarray().expect("failed");
|
||||
// let output = array2
|
||||
// .connected_components::<u16>(Connectivity::Four)
|
||||
// .expect("failed");
|
||||
// let expected = {
|
||||
// let mut expected = ndarray::Array2::zeros((10, 10));
|
||||
// expected.slice_mut(ndarray::s![2..4, 2..4]).fill(1);
|
||||
// expected.slice_mut(ndarray::s![6..9, 6..9]).fill(2);
|
||||
// expected
|
||||
// };
|
||||
|
||||
// assert_eq!(output, expected);
|
||||
// }
|
||||
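A hedged sketch of the ndarray-facing API above (trait imports are assumed to be re-exported by the crate): two disjoint squares in a binary mask should come back as two distinct labels plus the background.

```rust
use ndarray::s;

fn label_two_blobs() -> error_stack::Result<(), NdCvError> {
    let mut mask = ndarray::Array2::<u8>::zeros((10, 10));
    mask.slice_mut(s![2..4, 2..4]).fill(1);
    mask.slice_mut(s![6..9, 6..9]).fill(1);

    // Plain labelling: the output element type picks the OpenCV label type (i32 or u16).
    let labels: ndarray::Array2<i32> = mask.connected_components(Connectivity::Eight)?;
    assert_eq!(labels[(2, 2)], labels[(3, 3)]);
    assert_ne!(labels[(2, 2)], labels[(7, 7)]);

    // With stats: num_labels counts the background label as well, so 3 here.
    let stats: ConnectedComponentStats<i32> =
        mask.connected_components_with_stats(Connectivity::Eight)?;
    assert_eq!(stats.num_labels, 3);
    Ok(())
}
```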
270
ndcv-bridge/src/contours.rs
Normal file
@@ -0,0 +1,270 @@
|
||||
//! <https://docs.rs/opencv/latest/opencv/imgproc/fn.find_contours.html>
|
||||
|
||||
#![deny(warnings)]
|
||||
|
||||
use crate::conversions::*;
|
||||
use crate::prelude_::*;
|
||||
use nalgebra::Point2;
|
||||
use ndarray::*;
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)]
|
||||
pub enum ContourRetrievalMode {
|
||||
#[default]
|
||||
External = 0, // RETR_EXTERNAL
|
||||
List = 1, // RETR_LIST
|
||||
CComp = 2, // RETR_CCOMP
|
||||
Tree = 3, // RETR_TREE
|
||||
FloodFill = 4, // RETR_FLOODFILL
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)]
|
||||
pub enum ContourApproximationMethod {
|
||||
#[default]
|
||||
None = 1, // CHAIN_APPROX_NONE
|
||||
Simple = 2, // CHAIN_APPROX_SIMPLE
|
||||
Tc89L1 = 3, // CHAIN_APPROX_TC89_L1
|
||||
Tc89Kcos = 4, // CHAIN_APPROX_TC89_KCOS
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ContourHierarchy {
|
||||
pub next: i32,
|
||||
pub previous: i32,
|
||||
pub first_child: i32,
|
||||
pub parent: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ContourResult {
|
||||
pub contours: Vec<Vec<Point2<i32>>>,
|
||||
pub hierarchy: Vec<ContourHierarchy>,
|
||||
}
|
||||
|
||||
mod seal {
|
||||
pub trait Sealed {}
|
||||
impl Sealed for u8 {}
|
||||
}
|
||||
|
||||
pub trait NdCvFindContours<T: bytemuck::Pod + seal::Sealed>:
|
||||
crate::image::NdImage + crate::conversions::NdAsImage<T, ndarray::Ix2>
|
||||
{
|
||||
fn find_contours(
|
||||
&self,
|
||||
mode: ContourRetrievalMode,
|
||||
method: ContourApproximationMethod,
|
||||
) -> Result<Vec<Vec<Point2<i32>>>, NdCvError>;
|
||||
|
||||
fn find_contours_with_hierarchy(
|
||||
&self,
|
||||
mode: ContourRetrievalMode,
|
||||
method: ContourApproximationMethod,
|
||||
) -> Result<ContourResult, NdCvError>;
|
||||
|
||||
fn find_contours_def(&self) -> Result<Vec<Vec<Point2<i32>>>, NdCvError> {
|
||||
self.find_contours(
|
||||
ContourRetrievalMode::External,
|
||||
ContourApproximationMethod::Simple,
|
||||
)
|
||||
}
|
||||
|
||||
fn find_contours_with_hierarchy_def(&self) -> Result<ContourResult, NdCvError> {
|
||||
self.find_contours_with_hierarchy(
|
||||
ContourRetrievalMode::External,
|
||||
ContourApproximationMethod::Simple,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait NdCvContourArea<T: bytemuck::Pod> {
|
||||
fn contours_area(&self, oriented: bool) -> Result<f64, NdCvError>;
|
||||
|
||||
fn contours_area_def(&self) -> Result<f64, NdCvError> {
|
||||
self.contours_area(false)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ndarray::RawData + ndarray::Data<Elem = u8>> NdCvFindContours<u8> for ArrayBase<T, Ix2> {
|
||||
fn find_contours(
|
||||
&self,
|
||||
mode: ContourRetrievalMode,
|
||||
method: ContourApproximationMethod,
|
||||
) -> Result<Vec<Vec<Point2<i32>>>, NdCvError> {
|
||||
let cv_self = self.as_image_mat()?;
|
||||
let mut contours = opencv::core::Vector::<opencv::core::Vector<opencv::core::Point>>::new();
|
||||
|
||||
opencv::imgproc::find_contours(
|
||||
&*cv_self,
|
||||
&mut contours,
|
||||
mode as i32,
|
||||
method as i32,
|
||||
opencv::core::Point::new(0, 0),
|
||||
)
|
||||
.change_context(NdCvError)
|
||||
.attach_printable("Failed to find contours")?;
|
||||
let mut result: Vec<Vec<Point2<i32>>> = Vec::new();
|
||||
|
||||
for i in 0..contours.len() {
|
||||
let contour = contours.get(i).change_context(NdCvError)?;
|
||||
let points: Vec<Point2<i32>> =
|
||||
contour.iter().map(|pt| Point2::new(pt.x, pt.y)).collect();
|
||||
result.push(points);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn find_contours_with_hierarchy(
|
||||
&self,
|
||||
mode: ContourRetrievalMode,
|
||||
method: ContourApproximationMethod,
|
||||
) -> Result<ContourResult, NdCvError> {
|
||||
let cv_self = self.as_image_mat()?;
|
||||
let mut contours = opencv::core::Vector::<opencv::core::Vector<opencv::core::Point>>::new();
|
||||
let mut hierarchy = opencv::core::Vector::<opencv::core::Vec4i>::new();
|
||||
|
||||
opencv::imgproc::find_contours_with_hierarchy(
|
||||
&*cv_self,
|
||||
&mut contours,
|
||||
&mut hierarchy,
|
||||
mode as i32,
|
||||
method as i32,
|
||||
opencv::core::Point::new(0, 0),
|
||||
)
|
||||
.change_context(NdCvError)
|
||||
.attach_printable("Failed to find contours with hierarchy")?;
|
||||
let mut contour_list: Vec<Vec<Point2<i32>>> = Vec::new();
|
||||
|
||||
for i in 0..contours.len() {
|
||||
let contour = contours.get(i).change_context(NdCvError)?;
|
||||
let points: Vec<Point2<i32>> =
|
||||
contour.iter().map(|pt| Point2::new(pt.x, pt.y)).collect();
|
||||
contour_list.push(points);
|
||||
}
|
||||
|
||||
let mut hierarchy_list = Vec::new();
|
||||
for i in 0..hierarchy.len() {
|
||||
let h = hierarchy.get(i).change_context(NdCvError)?;
|
||||
hierarchy_list.push(ContourHierarchy {
|
||||
next: h[0],
|
||||
previous: h[1],
|
||||
first_child: h[2],
|
||||
parent: h[3],
|
||||
});
|
||||
}
|
||||
|
||||
Ok(ContourResult {
|
||||
contours: contour_list,
|
||||
hierarchy: hierarchy_list,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> NdCvContourArea<T> for Vec<Point2<T>>
|
||||
where
|
||||
T: bytemuck::Pod + num::traits::AsPrimitive<i32> + std::cmp::PartialEq + std::fmt::Debug + Copy,
|
||||
{
|
||||
fn contours_area(&self, oriented: bool) -> Result<f64, NdCvError> {
|
||||
if self.is_empty() {
|
||||
return Ok(0.0);
|
||||
}
|
||||
|
||||
let mut cv_contour: opencv::core::Vector<opencv::core::Point> = opencv::core::Vector::new();
|
||||
self.iter().for_each(|point| {
|
||||
cv_contour.push(opencv::core::Point::new(
|
||||
point.coords[0].as_(),
|
||||
point.coords[1].as_(),
|
||||
));
|
||||
});
|
||||
|
||||
opencv::imgproc::contour_area(&cv_contour, oriented)
|
||||
.change_context(NdCvError)
|
||||
.attach_printable("Failed to calculate contour area")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use ndarray::Array2;
|
||||
|
||||
fn simple_binary_rect_image() -> Array2<u8> {
|
||||
let mut img = Array2::<u8>::zeros((10, 10));
|
||||
for i in 2..8 {
|
||||
for j in 3..7 {
|
||||
img[(i, j)] = 255;
|
||||
}
|
||||
}
|
||||
img
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_contours_external_simple() {
|
||||
let img = simple_binary_rect_image();
|
||||
let contours = img
|
||||
.find_contours(
|
||||
ContourRetrievalMode::External,
|
||||
ContourApproximationMethod::Simple,
|
||||
)
|
||||
.expect("Failed to find contours");
|
||||
assert_eq!(contours.len(), 1);
|
||||
assert!(contours[0].len() >= 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_contours_with_hierarchy() {
|
||||
let img = simple_binary_rect_image();
|
||||
let res = img
|
||||
.find_contours_with_hierarchy(
|
||||
ContourRetrievalMode::External,
|
||||
ContourApproximationMethod::Simple,
|
||||
)
|
||||
.expect("Failed to find contours with hierarchy");
|
||||
assert_eq!(res.contours.len(), 1);
|
||||
assert_eq!(res.hierarchy.len(), 1);
|
||||
|
||||
let h = &res.hierarchy[0];
|
||||
assert_eq!(h.parent, -1);
|
||||
assert_eq!(h.first_child, -1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_methods() {
|
||||
let img = simple_binary_rect_image();
|
||||
let contours = img.find_contours_def().unwrap();
|
||||
let res = img.find_contours_with_hierarchy_def().unwrap();
|
||||
assert_eq!(contours.len(), 1);
|
||||
assert_eq!(res.contours.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_contour_area_calculation() {
|
||||
let img = simple_binary_rect_image();
|
||||
let contours = img.find_contours_def().unwrap();
|
||||
let expected_area = 15.;
|
||||
let area = contours[0].contours_area_def().unwrap();
|
||||
assert!(
|
||||
(area - expected_area).abs() < 1.0,
|
||||
"Area mismatch: got {area}, expected {expected_area}",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_input_returns_no_contours() {
|
||||
let img = Array2::<u8>::zeros((10, 10));
|
||||
let contours = img.find_contours_def().unwrap();
|
||||
assert!(contours.is_empty());
|
||||
|
||||
let res = img.find_contours_with_hierarchy_def().unwrap();
|
||||
assert!(res.contours.is_empty());
|
||||
assert!(res.hierarchy.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_contour_area_empty_contour() {
|
||||
let contour: Vec<Point2<i32>> = vec![];
|
||||
let area = contour.contours_area_def().unwrap();
|
||||
assert_eq!(area, 0.0);
|
||||
}
|
||||
}
|
||||
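As a usage sketch built on the defaults above (trait imports assumed in scope), contours can be filtered by area before any further processing:

```rust
fn large_contours(
    mask: &ndarray::Array2<u8>,
    min_area: f64,
) -> error_stack::Result<Vec<Vec<nalgebra::Point2<i32>>>, NdCvError> {
    let mut kept = Vec::new();
    for contour in mask.find_contours_def()? {
        // contour_area is computed by OpenCV on the polygon, not on the filled pixels.
        if contour.contours_area_def()? >= min_area {
            kept.push(contour);
        }
    }
    Ok(kept)
}
```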
337
ndcv-bridge/src/conversions.rs
Normal file
@@ -0,0 +1,337 @@
|
||||
//! Mat <--> ndarray conversion traits
|
||||
//!
|
||||
//! Conversion Table
|
||||
//!
|
||||
//! | ndarray | Mat |
|
||||
//! |--------- |----- |
|
||||
//! | Array<T, Ix1> | Mat(ndims = 1, channels = 1) |
|
||||
//! | Array<T, Ix2> | Mat(ndims = 2, channels = 1) |
|
||||
//! | Array<T, Ix2> | Mat(ndims = 1, channels = X) |
|
||||
//! | Array<T, Ix3> | Mat(ndims = 3, channels = 1) |
|
||||
//! | Array<T, Ix3> | Mat(ndims = 2, channels = X) |
|
||||
//! | Array<T, Ix4> | Mat(ndims = 4, channels = 1) |
|
||||
//! | Array<T, Ix4> | Mat(ndims = 3, channels = X) |
|
||||
//! | Array<T, Ix5> | Mat(ndims = 5, channels = 1) |
|
||||
//! | Array<T, Ix5> | Mat(ndims = 4, channels = X) |
|
||||
//! | Array<T, Ix6> | Mat(ndims = 6, channels = 1) |
|
||||
//! | Array<T, Ix6> | Mat(ndims = 5, channels = X) |
|
||||
//!
|
||||
//! // X is the last dimension
|
||||
use crate::NdCvError;
|
||||
use crate::type_depth;
|
||||
use error_stack::*;
|
||||
use ndarray::{Ix2, Ix3};
|
||||
use opencv::core::MatTraitConst;
|
||||
mod impls;
|
||||
pub(crate) mod matref;
|
||||
use matref::{MatRef, MatRefMut};
|
||||
|
||||
pub(crate) mod seal {
|
||||
pub trait SealedInternal {}
|
||||
impl<T, S: ndarray::Data<Elem = T>, D> SealedInternal for ndarray::ArrayBase<S, D> {}
|
||||
// impl<T, S: ndarray::DataMut<Elem = T>, D> SealedInternal for ndarray::ArrayBase<S, D> {}
|
||||
}
|
||||
|
||||
pub trait NdCvConversion<T: bytemuck::Pod + Copy, D: ndarray::Dimension>:
|
||||
seal::SealedInternal + Sized
|
||||
{
|
||||
fn to_mat(&self) -> Result<opencv::core::Mat, NdCvError>;
|
||||
fn from_mat(
|
||||
mat: opencv::core::Mat,
|
||||
) -> Result<ndarray::ArrayBase<ndarray::OwnedRepr<T>, D>, NdCvError>;
|
||||
}
|
||||
|
||||
impl<T: bytemuck::Pod + Copy, S: ndarray::Data<Elem = T>, D: ndarray::Dimension>
|
||||
NdCvConversion<T, D> for ndarray::ArrayBase<S, D>
|
||||
where
|
||||
Self: NdAsImage<T, D>,
|
||||
{
|
||||
fn to_mat(&self) -> Result<opencv::core::Mat, NdCvError> {
|
||||
Ok(self.as_image_mat()?.mat.clone())
|
||||
}
|
||||
|
||||
fn from_mat(
|
||||
mat: opencv::core::Mat,
|
||||
) -> Result<ndarray::ArrayBase<ndarray::OwnedRepr<T>, D>, NdCvError> {
|
||||
let ndarray = unsafe { impls::mat_to_ndarray::<T, D>(&mat) }.change_context(NdCvError)?;
|
||||
Ok(ndarray.to_owned())
|
||||
}
|
||||
}
|
||||
|
||||
pub trait MatAsNd {
|
||||
fn as_ndarray<T: bytemuck::Pod, D: ndarray::Dimension>(
|
||||
&self,
|
||||
) -> Result<ndarray::ArrayView<T, D>, NdCvError>;
|
||||
}
|
||||
|
||||
impl MatAsNd for opencv::core::Mat {
|
||||
fn as_ndarray<T: bytemuck::Pod, D: ndarray::Dimension>(
|
||||
&self,
|
||||
) -> Result<ndarray::ArrayView<T, D>, NdCvError> {
|
||||
unsafe { impls::mat_to_ndarray::<T, D>(self) }.change_context(NdCvError)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait NdAsMat<T: bytemuck::Pod + Copy, D: ndarray::Dimension> {
|
||||
fn as_single_channel_mat(&self) -> Result<MatRef, NdCvError>;
|
||||
fn as_multi_channel_mat(&self) -> Result<MatRef, NdCvError>;
|
||||
}
|
||||
|
||||
pub trait NdAsMatMut<T: bytemuck::Pod + Copy, D: ndarray::Dimension>: NdAsMat<T, D> {
|
||||
fn as_single_channel_mat_mut(&mut self) -> Result<MatRefMut, NdCvError>;
|
||||
fn as_multi_channel_mat_mut(&mut self) -> Result<MatRefMut, NdCvError>;
|
||||
}
|
||||
|
||||
impl<T: bytemuck::Pod, S: ndarray::Data<Elem = T>, D: ndarray::Dimension> NdAsMat<T, D>
|
||||
for ndarray::ArrayBase<S, D>
|
||||
{
|
||||
fn as_single_channel_mat(&self) -> Result<MatRef, NdCvError> {
|
||||
let mat = unsafe { impls::ndarray_to_mat_regular(self) }.change_context(NdCvError)?;
|
||||
Ok(MatRef::new(mat))
|
||||
}
|
||||
fn as_multi_channel_mat(&self) -> Result<MatRef, NdCvError> {
|
||||
let mat = unsafe { impls::ndarray_to_mat_consolidated(self) }.change_context(NdCvError)?;
|
||||
Ok(MatRef::new(mat))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: bytemuck::Pod, S: ndarray::DataMut<Elem = T>, D: ndarray::Dimension> NdAsMatMut<T, D>
|
||||
for ndarray::ArrayBase<S, D>
|
||||
{
|
||||
fn as_single_channel_mat_mut(&mut self) -> Result<MatRefMut, NdCvError> {
|
||||
let mat = unsafe { impls::ndarray_to_mat_regular(self) }.change_context(NdCvError)?;
|
||||
Ok(MatRefMut::new(mat))
|
||||
}
|
||||
|
||||
fn as_multi_channel_mat_mut(&mut self) -> Result<MatRefMut, NdCvError> {
|
||||
let mat = unsafe { impls::ndarray_to_mat_consolidated(self) }.change_context(NdCvError)?;
|
||||
Ok(MatRefMut::new(mat))
|
||||
}
|
||||
}
|
||||
|
||||
pub trait NdAsImage<T: bytemuck::Pod, D: ndarray::Dimension> {
|
||||
fn as_image_mat(&self) -> Result<MatRef, NdCvError>;
|
||||
}
|
||||
|
||||
pub trait NdAsImageMut<T: bytemuck::Pod, D: ndarray::Dimension> {
|
||||
fn as_image_mat_mut(&mut self) -> Result<MatRefMut, NdCvError>;
|
||||
}
|
||||
|
||||
impl<T, S> NdAsImage<T, Ix2> for ndarray::ArrayBase<S, Ix2>
|
||||
where
|
||||
T: bytemuck::Pod + Copy,
|
||||
S: ndarray::Data<Elem = T>,
|
||||
{
|
||||
fn as_image_mat(&self) -> Result<MatRef, NdCvError> {
|
||||
self.as_single_channel_mat()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, S> NdAsImageMut<T, Ix2> for ndarray::ArrayBase<S, Ix2>
|
||||
where
|
||||
T: bytemuck::Pod + Copy,
|
||||
S: ndarray::DataMut<Elem = T>,
|
||||
{
|
||||
fn as_image_mat_mut(&mut self) -> Result<MatRefMut, NdCvError> {
|
||||
self.as_single_channel_mat_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, S> NdAsImage<T, Ix3> for ndarray::ArrayBase<S, Ix3>
|
||||
where
|
||||
T: bytemuck::Pod + Copy,
|
||||
S: ndarray::Data<Elem = T>,
|
||||
{
|
||||
fn as_image_mat(&self) -> Result<MatRef, NdCvError> {
|
||||
self.as_multi_channel_mat()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, S> NdAsImageMut<T, Ix3> for ndarray::ArrayBase<S, Ix3>
|
||||
where
|
||||
T: bytemuck::Pod + Copy,
|
||||
S: ndarray::DataMut<Elem = T>,
|
||||
{
|
||||
fn as_image_mat_mut(&mut self) -> Result<MatRefMut, NdCvError> {
|
||||
self.as_multi_channel_mat_mut()
|
||||
}
|
||||
}
|
||||
|
||||
// #[test]
|
||||
// fn test_1d_mat_to_ndarray() {
|
||||
// let mat = opencv::core::Mat::new_nd_with_default(
|
||||
// &[10],
|
||||
// opencv::core::CV_MAKE_TYPE(opencv::core::CV_8U, 1),
|
||||
// 200.into(),
|
||||
// )
|
||||
// .expect("failed");
|
||||
// let array: ndarray::ArrayView1<u8> = mat.as_ndarray().expect("failed");
|
||||
// array.into_iter().for_each(|&x| assert_eq!(x, 200));
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_2d_mat_to_ndarray() {
|
||||
// let mat = opencv::core::Mat::new_nd_with_default(
|
||||
// &[10],
|
||||
// opencv::core::CV_16SC3,
|
||||
// (200, 200, 200).into(),
|
||||
// )
|
||||
// .expect("failed");
|
||||
// let array2: ndarray::ArrayView2<i16> = mat.as_ndarray().expect("failed");
|
||||
// assert_eq!(array2.shape(), [10, 3]);
|
||||
// array2.into_iter().for_each(|&x| {
|
||||
// assert_eq!(x, 200);
|
||||
// });
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_3d_mat_to_ndarray() {
|
||||
// let mat = opencv::core::Mat::new_nd_with_default(
|
||||
// &[20, 30],
|
||||
// opencv::core::CV_32FC3,
|
||||
// (200, 200, 200).into(),
|
||||
// )
|
||||
// .expect("failed");
|
||||
// let array2: ndarray::ArrayView3<f32> = mat.as_ndarray().expect("failed");
|
||||
// array2.into_iter().for_each(|&x| {
|
||||
// assert_eq!(x, 200f32);
|
||||
// });
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_mat_to_dyn_ndarray() {
|
||||
// let mat = opencv::core::Mat::new_nd_with_default(&[10], opencv::core::CV_8UC1, 200.into())
|
||||
// .expect("failed");
|
||||
// let array2: ndarray::ArrayViewD<u8> = mat.as_ndarray().expect("failed");
|
||||
// array2.into_iter().for_each(|&x| assert_eq!(x, 200));
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_3d_mat_to_ndarray_4k() {
|
||||
// let mat = opencv::core::Mat::new_nd_with_default(
|
||||
// &[4096, 4096],
|
||||
// opencv::core::CV_8UC3,
|
||||
// (255, 0, 255).into(),
|
||||
// )
|
||||
// .expect("failed");
|
||||
// let array2: ndarray::ArrayView3<u8> = (mat).as_ndarray().expect("failed");
|
||||
// array2.exact_chunks((1, 1, 3)).into_iter().for_each(|x| {
|
||||
// assert_eq!(x[(0, 0, 0)], 255);
|
||||
// assert_eq!(x[(0, 0, 1)], 0);
|
||||
// assert_eq!(x[(0, 0, 2)], 255);
|
||||
// });
|
||||
// }
|
||||
|
||||
// // #[test]
|
||||
// // fn test_3d_mat_to_ndarray_8k() {
|
||||
// // let mat = opencv::core::Mat::new_nd_with_default(
|
||||
// // &[8192, 8192],
|
||||
// // opencv::core::CV_8UC3,
|
||||
// // (255, 0, 255).into(),
|
||||
// // )
|
||||
// // .expect("failed");
|
||||
// // let array2 = ndarray::Array3::<u8>::from_mat(mat).expect("failed");
|
||||
// // array2.exact_chunks((1, 1, 3)).into_iter().for_each(|x| {
|
||||
// // assert_eq!(x[(0, 0, 0)], 255);
|
||||
// // assert_eq!(x[(0, 0, 1)], 0);
|
||||
// // assert_eq!(x[(0, 0, 2)], 255);
|
||||
// // });
|
||||
// // }
|
||||
|
||||
// #[test]
|
||||
// pub fn test_mat_to_nd_default_strides() {
|
||||
// let mat = opencv::core::Mat::new_rows_cols_with_default(
|
||||
// 10,
|
||||
// 10,
|
||||
// opencv::core::CV_8UC3,
|
||||
// opencv::core::VecN([10f64, 0.0, 0.0, 0.0]),
|
||||
// )
|
||||
// .expect("failed");
|
||||
// let array = unsafe { impls::mat_to_ndarray::<u8, Ix3>(&mat) }.expect("failed");
|
||||
// assert_eq!(array.shape(), [10, 10, 3]);
|
||||
// assert_eq!(array.strides(), [30, 3, 1]);
|
||||
// assert_eq!(array[(0, 0, 0)], 10);
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// pub fn test_mat_to_nd_custom_strides() {
|
||||
// let mat = opencv::core::Mat::new_rows_cols_with_default(
|
||||
// 10,
|
||||
// 10,
|
||||
// opencv::core::CV_8UC3,
|
||||
// opencv::core::VecN([10f64, 0.0, 0.0, 0.0]),
|
||||
// )
|
||||
// .unwrap();
|
||||
// let mat_roi = opencv::core::Mat::roi(&mat, opencv::core::Rect::new(3, 2, 3, 5))
|
||||
// .expect("failed to get roi");
|
||||
// let array = unsafe { impls::mat_to_ndarray::<u8, Ix3>(&mat_roi) }.expect("failed");
|
||||
// assert_eq!(array.shape(), [5, 3, 3]);
|
||||
// assert_eq!(array.strides(), [30, 3, 1]);
|
||||
// assert_eq!(array[(0, 0, 0)], 10);
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// pub fn test_non_continuous_3d() {
|
||||
// let array = ndarray::Array3::<f32>::from_shape_fn((10, 10, 4), |(i, j, k)| {
|
||||
// ((i + 1) * (j + 1) * (k + 1)) as f32
|
||||
// });
|
||||
// let slice = array.slice(ndarray::s![3..7, 3..7, 0..4]);
|
||||
// let mat = unsafe { impls::ndarray_to_mat_consolidated(&slice) }.unwrap();
|
||||
// let arr = unsafe { impls::mat_to_ndarray::<f32, Ix3>(&mat).unwrap() };
|
||||
// assert!(slice == arr);
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// pub fn test_5d_array() {
|
||||
// let array = ndarray::Array5::<f32>::ones((1, 2, 3, 4, 5));
|
||||
// let mat = unsafe { impls::ndarray_to_mat_consolidated(&array) }.unwrap();
|
||||
// let arr = unsafe { impls::mat_to_ndarray::<f32, ndarray::Ix5>(&mat).unwrap() };
|
||||
// assert_eq!(array, arr);
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// pub fn test_3d_array() {
|
||||
// let array = ndarray::Array3::<f32>::ones((23, 31, 33));
|
||||
// let mat = unsafe { impls::ndarray_to_mat_consolidated(&array) }.unwrap();
|
||||
// let arr = unsafe { impls::mat_to_ndarray::<f32, ndarray::Ix3>(&mat).unwrap() };
|
||||
// assert_eq!(array, arr);
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// pub fn test_2d_array() {
|
||||
// let array = ndarray::Array2::<f32>::ones((23, 31));
|
||||
// let mat = unsafe { impls::ndarray_to_mat_consolidated(&array) }.unwrap();
|
||||
// let arr = unsafe { impls::mat_to_ndarray::<f32, ndarray::Ix2>(&mat).unwrap() };
|
||||
// assert_eq!(array, arr);
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// #[should_panic]
|
||||
// pub fn test_1d_array_consolidated() {
|
||||
// let array = ndarray::Array1::<f32>::ones(23);
|
||||
// let mat = unsafe { impls::ndarray_to_mat_consolidated(&array) }.unwrap();
|
||||
// let arr = unsafe { impls::mat_to_ndarray::<f32, ndarray::Ix1>(&mat).unwrap() };
|
||||
// assert_eq!(array, arr);
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// pub fn test_1d_array_regular() {
|
||||
// let array = ndarray::Array1::<f32>::ones(23);
|
||||
// let mat = unsafe { impls::ndarray_to_mat_regular(&array) }.unwrap();
|
||||
// let arr = unsafe { impls::mat_to_ndarray::<f32, ndarray::Ix1>(&mat).unwrap() };
|
||||
// assert_eq!(array, arr);
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// pub fn test_2d_array_regular() {
|
||||
// let array = ndarray::Array2::<f32>::ones((23, 31));
|
||||
// let mat = unsafe { impls::ndarray_to_mat_regular(&array) }.unwrap();
|
||||
// let arr = unsafe { impls::mat_to_ndarray::<f32, ndarray::Ix2>(&mat).unwrap() };
|
||||
// assert_eq!(array, arr);
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// pub fn test_ndcv_1024_1024_to_mat() {
|
||||
// let array = ndarray::Array2::<f32>::ones((1024, 1024));
|
||||
// let _mat = array.to_mat().unwrap();
|
||||
// }
|
||||
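A round-trip sketch of the conversion table at the top of this file (the `NdCvConversion` and `MatAsNd` traits are assumed to be in scope): an (H, W, 3) u8 array becomes a 2-D, 3-channel Mat and can be viewed back without copying.

```rust
fn round_trip() -> error_stack::Result<(), NdCvError> {
    let arr = ndarray::Array3::<u8>::zeros((480, 640, 3));
    let mat = arr.to_mat()?; // Mat(ndims = 2, channels = 3)
    let view: ndarray::ArrayView3<u8> = mat.as_ndarray()?; // borrows the Mat's buffer
    assert_eq!(view.shape(), &[480, 640, 3]);
    Ok(())
}
```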
168
ndcv-bridge/src/conversions/impls.rs
Normal file
@@ -0,0 +1,168 @@
|
||||
use super::*;
|
||||
use core::ffi::*;
|
||||
use opencv::core::prelude::*;
|
||||
pub(crate) unsafe fn ndarray_to_mat_regular<
|
||||
T,
|
||||
S: ndarray::Data<Elem = T>,
|
||||
D: ndarray::Dimension,
|
||||
>(
|
||||
input: &ndarray::ArrayBase<S, D>,
|
||||
) -> Result<opencv::core::Mat, NdCvError> {
|
||||
let shape = input.shape();
|
||||
let strides = input.strides();
|
||||
|
||||
// let channels = shape.last().copied().unwrap_or(1);
|
||||
// if channels > opencv::core::CV_CN_MAX as usize {
|
||||
// Err(Report::new(NdCvError).attach_printable(format!(
|
||||
// "Number of channels({channels}) exceeds CV_CN_MAX({}) use the regular version of the function", opencv::core::CV_CN_MAX
|
||||
// )))?;
|
||||
// }
|
||||
|
||||
// let size_len = shape.len();
|
||||
let size = shape.iter().copied().map(|f| f as i32).collect::<Vec<_>>();
|
||||
// OpenCV expects ndims - 1 steps (the final step is implied by the element size).
|
||||
let step_len = strides.len() - 1;
|
||||
let step = strides
|
||||
.iter()
|
||||
.take(step_len)
|
||||
.copied()
|
||||
.map(|f| f as usize * core::mem::size_of::<T>())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let data_ptr = input.as_ptr() as *const c_void;
|
||||
|
||||
let typ = opencv::core::CV_MAKETYPE(type_depth::<T>(), 1);
|
||||
let mat = opencv::core::Mat::new_nd_with_data_unsafe(
|
||||
size.as_slice(),
|
||||
typ,
|
||||
data_ptr.cast_mut(),
|
||||
Some(step.as_slice()),
|
||||
)
|
||||
.change_context(NdCvError)?;
|
||||
|
||||
Ok(mat)
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn ndarray_to_mat_consolidated<
|
||||
T,
|
||||
S: ndarray::Data<Elem = T>,
|
||||
D: ndarray::Dimension,
|
||||
>(
|
||||
input: &ndarray::ArrayBase<S, D>,
|
||||
) -> Result<opencv::core::Mat, NdCvError> {
|
||||
let shape = input.shape();
|
||||
let strides = input.strides();
|
||||
|
||||
let channels = shape.last().copied().unwrap_or(1);
|
||||
if channels > opencv::core::CV_CN_MAX as usize {
|
||||
Err(Report::new(NdCvError).attach_printable(format!(
|
||||
"Number of channels({channels}) exceeds CV_CN_MAX({}) use the regular version of the function", opencv::core::CV_CN_MAX
|
||||
)))?;
|
||||
}
|
||||
|
||||
if shape.len() > 2 {
|
||||
// The second-to-last stride is what moves from one column (pixel) to the next, but
// the consolidated Mat only keeps ndims - 1 steps, so that stride cannot be carried
// over: it must equal the channel count, i.e. the last axis has to be contiguous
// (e.g. an (H, W, 3) view needs a stride of 3 along W).
|
||||
if shape.last() != strides.get(strides.len() - 2).map(|x| *x as usize).as_ref() {
|
||||
Err(Report::new(NdCvError).attach_printable(
|
||||
"You cannot slice into the last axis in ndarray when converting to mat",
|
||||
))?;
|
||||
}
|
||||
} else if shape.len() == 1 {
|
||||
return Err(Report::new(NdCvError).attach_printable(
|
||||
"You cannot convert a 1D array to a Mat while using the consolidated version",
|
||||
));
|
||||
}
|
||||
|
||||
// Since this is the consolidated version we should always only have ndims - 1 sizes and
|
||||
// ndims - 2 strides
|
||||
|
||||
let size_len = shape.len() - 1; // Since we move last axis into the channel
|
||||
let size = shape
|
||||
.iter()
|
||||
.take(size_len)
|
||||
.map(|f| *f as i32)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let step_len = strides.len() - 1;
|
||||
let step = strides
|
||||
.iter()
|
||||
.take(step_len)
|
||||
.map(|f| *f as usize * core::mem::size_of::<T>())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let data_ptr = input.as_ptr() as *const c_void;
|
||||
|
||||
let typ = opencv::core::CV_MAKETYPE(type_depth::<T>(), channels as i32);
|
||||
|
||||
let mat = opencv::core::Mat::new_nd_with_data_unsafe(
|
||||
size.as_slice(),
|
||||
typ,
|
||||
data_ptr.cast_mut(),
|
||||
Some(step.as_slice()),
|
||||
)
|
||||
.change_context(NdCvError)?;
|
||||
|
||||
Ok(mat)
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn mat_to_ndarray<T: bytemuck::Pod, D: ndarray::Dimension>(
|
||||
mat: &opencv::core::Mat,
|
||||
) -> Result<ndarray::ArrayView<'_, T, D>, NdCvError> {
|
||||
let depth = mat.depth();
|
||||
if type_depth::<T>() != depth {
|
||||
return Err(Report::new(NdCvError).attach_printable(format!(
|
||||
"Expected type Mat<{}> ({}), got Mat<{}> ({})",
|
||||
std::any::type_name::<T>(),
|
||||
type_depth::<T>(),
|
||||
crate::depth_type(depth),
|
||||
depth,
|
||||
)));
|
||||
}
|
||||
|
||||
// Since Mat::dims() always returns >= 2 we can't use it to check for a 1D array,
// so we compare the first axis length against the total element count instead.
|
||||
let is_1d = mat.total() as i32 == mat.rows();
|
||||
let dims = is_1d.then_some(1).unwrap_or(mat.dims());
|
||||
let channels = mat.channels();
|
||||
let ndarray_size = (channels != 1).then_some(dims + 1).unwrap_or(dims) as usize;
|
||||
if let Some(ndim) = D::NDIM {
|
||||
// When channels is not 1,
|
||||
// the last dimension is the channels
|
||||
// Array1 -> Mat(ndims = 1, channels = 1)
|
||||
// Array2 -> Mat(ndims = 1, channels = X)
|
||||
// Array2 -> Mat(ndims = 2, channels = 1)
|
||||
// Array3 -> Mat(ndims = 2, channels = X)
|
||||
// Array3 -> Mat(ndims = 3, channels = 1)
|
||||
// ...
|
||||
if ndim != dims as usize && channels == 1 {
|
||||
return Err(Report::new(NdCvError)
|
||||
.attach_printable(format!("Expected {}D array, got {}D", ndim, ndarray_size)));
|
||||
}
|
||||
}
|
||||
|
||||
let mat_size = mat.mat_size();
|
||||
let sizes = (0..dims)
|
||||
.map(|i| mat_size.get(i).change_context(NdCvError))
|
||||
.chain([Ok(channels)])
|
||||
.map(|x| x.map(|x| x as usize))
|
||||
.take(ndarray_size)
|
||||
.collect::<Result<Vec<_>, NdCvError>>()
|
||||
.change_context(NdCvError)?;
|
||||
let strides = (0..(dims - 1))
|
||||
.map(|i| mat.step1(i).change_context(NdCvError))
|
||||
.chain([
|
||||
Ok(channels as usize),
|
||||
Ok((channels == 1).then_some(0).unwrap_or(1)),
|
||||
])
|
||||
.take(ndarray_size)
|
||||
.collect::<Result<Vec<_>, NdCvError>>()
|
||||
.change_context(NdCvError)?;
|
||||
use ndarray::ShapeBuilder;
|
||||
let shape = sizes.strides(strides);
|
||||
let raw_array = ndarray::RawArrayView::from_shape_ptr(shape, mat.data() as *const T)
|
||||
.into_dimensionality()
|
||||
.change_context(NdCvError)?;
|
||||
Ok(unsafe { raw_array.deref_into_view() })
|
||||
}
|
||||
73
ndcv-bridge/src/conversions/matref.rs
Normal file
@@ -0,0 +1,73 @@
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MatRef<'a> {
|
||||
pub(crate) mat: opencv::core::Mat,
|
||||
pub(crate) _marker: core::marker::PhantomData<&'a ()>,
|
||||
}
|
||||
|
||||
impl MatRef<'_> {
|
||||
pub fn clone_pointee(&self) -> opencv::core::Mat {
|
||||
self.mat.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl MatRef<'_> {
|
||||
pub fn new<'a>(mat: opencv::core::Mat) -> MatRef<'a> {
|
||||
MatRef {
|
||||
mat,
|
||||
_marker: core::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<opencv::core::Mat> for MatRef<'_> {
|
||||
fn as_ref(&self) -> &opencv::core::Mat {
|
||||
&self.mat
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<opencv::core::Mat> for MatRefMut<'_> {
|
||||
fn as_ref(&self) -> &opencv::core::Mat {
|
||||
&self.mat
|
||||
}
|
||||
}
|
||||
|
||||
impl AsMut<opencv::core::Mat> for MatRefMut<'_> {
|
||||
fn as_mut(&mut self) -> &mut opencv::core::Mat {
|
||||
&mut self.mat
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MatRefMut<'a> {
|
||||
pub(crate) mat: opencv::core::Mat,
|
||||
pub(crate) _marker: core::marker::PhantomData<&'a mut ()>,
|
||||
}
|
||||
|
||||
impl MatRefMut<'_> {
|
||||
pub fn new<'a>(mat: opencv::core::Mat) -> MatRefMut<'a> {
|
||||
MatRefMut {
|
||||
mat,
|
||||
_marker: core::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl core::ops::Deref for MatRef<'_> {
|
||||
type Target = opencv::core::Mat;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.mat
|
||||
}
|
||||
}
|
||||
|
||||
impl core::ops::Deref for MatRefMut<'_> {
|
||||
type Target = opencv::core::Mat;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.mat
|
||||
}
|
||||
}
|
||||
|
||||
impl core::ops::DerefMut for MatRefMut<'_> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.mat
|
||||
}
|
||||
}
|
||||
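MatRef/MatRefMut exist so that a Mat built over ndarray-owned memory cannot outlive the borrow it came from; a crate-internal sketch (the `NdAsMat` trait from conversions.rs is assumed in scope):

```rust
fn borrow_as_mat(arr: &ndarray::Array2<u8>) -> error_stack::Result<(), NdCvError> {
    // The returned MatRef aliases arr's buffer; its lifetime is tied to &arr by the
    // trait method signature, so the Mat view cannot escape the borrow.
    let mat_ref = arr.as_single_channel_mat()?;
    let mat: &opencv::core::Mat = mat_ref.as_ref();
    let _ = mat;
    Ok(())
}
```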
262
ndcv-bridge/src/fir.rs
Normal file
@@ -0,0 +1,262 @@
|
||||
use error_stack::*;
|
||||
use fast_image_resize::*;
|
||||
use images::{Image, ImageRef};
|
||||
#[derive(Debug, Clone, thiserror::Error)]
|
||||
#[error("NdFirError")]
|
||||
pub struct NdFirError;
|
||||
type Result<T, E = Report<NdFirError>> = std::result::Result<T, E>;
|
||||
|
||||
pub trait NdAsImage<T: seal::Sealed, D: ndarray::Dimension>: Sized {
|
||||
fn as_image_ref(&self) -> Result<ImageRef<'_>>;
|
||||
}
|
||||
|
||||
pub trait NdAsImageMut<T: seal::Sealed, D: ndarray::Dimension>: Sized {
|
||||
fn as_image_ref_mut(&mut self) -> Result<Image<'_>>;
|
||||
}
|
||||
|
||||
pub struct NdarrayImageContainer<'a, T: seal::Sealed, D: ndarray::Dimension> {
|
||||
#[allow(dead_code)]
|
||||
data: ndarray::ArrayView<'a, T, D>,
|
||||
pub _phantom: std::marker::PhantomData<(T, D)>,
|
||||
}
|
||||
|
||||
impl<'a, T: seal::Sealed> NdarrayImageContainer<'a, T, ndarray::Ix3> {
|
||||
pub fn new<S: ndarray::Data<Elem = T>>(array: &'a ndarray::ArrayBase<S, ndarray::Ix3>) -> Self {
|
||||
Self {
|
||||
data: array.view(),
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: seal::Sealed> NdarrayImageContainer<'a, T, ndarray::Ix2> {
|
||||
pub fn new<S: ndarray::Data<Elem = T>>(array: &'a ndarray::ArrayBase<S, ndarray::Ix2>) -> Self {
|
||||
Self {
|
||||
data: array.view(),
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
pub struct NdarrayImageContainerMut<'a, T: seal::Sealed, D: ndarray::Dimension> {
|
||||
#[allow(dead_code)]
|
||||
data: ndarray::ArrayViewMut<'a, T, D>,
|
||||
}
|
||||
|
||||
impl<'a, T: seal::Sealed> NdarrayImageContainerMut<'a, T, ndarray::Ix3> {
|
||||
pub fn new<S: ndarray::DataMut<Elem = T>>(
|
||||
array: &'a mut ndarray::ArrayBase<S, ndarray::Ix3>,
|
||||
) -> Self {
|
||||
Self {
|
||||
data: array.view_mut(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: seal::Sealed> NdarrayImageContainerMut<'a, T, ndarray::Ix2> {
|
||||
pub fn new<S: ndarray::DataMut<Elem = T>>(
|
||||
array: &'a mut ndarray::ArrayBase<S, ndarray::Ix2>,
|
||||
) -> Self {
|
||||
Self {
|
||||
data: array.view_mut(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NdarrayImageContainerTyped<'a, T: seal::Sealed, D: ndarray::Dimension, P: PixelTrait> {
|
||||
#[allow(dead_code)]
|
||||
data: ndarray::ArrayView<'a, T, D>,
|
||||
__marker: std::marker::PhantomData<P>,
|
||||
}
|
||||
|
||||
// unsafe impl<'a, T: seal::Sealed + Sync + InnerPixel, P: PixelTrait> ImageView
|
||||
// for NdarrayImageContainerTyped<'a, T, ndarray::Ix3, P>
|
||||
// where
|
||||
// T: bytemuck::Pod,
|
||||
// {
|
||||
// type Pixel = P;
|
||||
// fn width(&self) -> u32 {
|
||||
// self.data.shape()[1] as u32
|
||||
// }
|
||||
// fn height(&self) -> u32 {
|
||||
// self.data.shape()[0] as u32
|
||||
// }
|
||||
// fn iter_rows(&self, start_row: u32) -> impl Iterator<Item = &[Self::Pixel]> {
|
||||
// self.data
|
||||
// .rows()
|
||||
// .into_iter()
|
||||
// .skip(start_row as usize)
|
||||
// .map(|row| {
|
||||
// row.as_slice()
|
||||
// .unwrap_or_default()
|
||||
// .chunks_exact(P::CHANNELS as usize)
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
|
||||
// impl<'a, T: fast_image_resize::pixels::InnerPixel + seal::Sealed, D: ndarray::Dimension>
|
||||
// fast_image_resize::IntoImageView for NdarrayImageContainer<'a, T, D>
|
||||
// {
|
||||
// fn pixel_type(&self) -> Option<PixelType> {
|
||||
// match D::NDIM {
|
||||
// Some(2) => Some(to_pixel_type::<T>(1).expect("Failed to convert to pixel type")),
|
||||
// Some(3) => Some(
|
||||
// to_pixel_type::<T>(self.data.shape()[2]).expect("Failed to convert to pixel type"),
|
||||
// ),
|
||||
// _ => None,
|
||||
// }
|
||||
// }
|
||||
// fn width(&self) -> u32 {
|
||||
// self.data.shape()[1] as u32
|
||||
// }
|
||||
// fn height(&self) -> u32 {
|
||||
// self.data.shape()[0] as u32
|
||||
// }
|
||||
// fn image_view<P: PixelTrait>(&'a self) -> Option<NdarrayImageContainerTyped<'a, T, D, P>> {
|
||||
// Some(NdarrayImageContainerTyped {
|
||||
// data: self.data.view(),
|
||||
// __marker: std::marker::PhantomData,
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
|
||||
pub fn to_pixel_type<T: seal::Sealed>(u: usize) -> Result<PixelType> {
|
||||
match (core::any::type_name::<T>(), u) {
|
||||
("u8", 1) => Ok(PixelType::U8),
|
||||
("u8", 2) => Ok(PixelType::U8x2),
|
||||
("u8", 3) => Ok(PixelType::U8x3),
|
||||
("u8", 4) => Ok(PixelType::U8x4),
|
||||
("u16", 1) => Ok(PixelType::U16),
|
||||
("i32", 1) => Ok(PixelType::I32),
|
||||
("f32", 1) => Ok(PixelType::F32),
|
||||
("f32", 2) => Ok(PixelType::F32x2),
|
||||
("f32", 3) => Ok(PixelType::F32x3),
|
||||
("f32", 4) => Ok(PixelType::F32x4),
|
||||
_ => Err(Report::new(NdFirError).attach_printable("Unsupported pixel type")),
|
||||
}
|
||||
}
|
||||
|
||||
mod seal {
|
||||
pub trait Sealed {}
|
||||
impl Sealed for u8 {}
|
||||
impl Sealed for u16 {}
|
||||
impl Sealed for i32 {}
|
||||
impl Sealed for f32 {}
|
||||
}
|
||||
|
||||
impl<S: ndarray::Data<Elem = T>, T: seal::Sealed + bytemuck::Pod, D: ndarray::Dimension>
|
||||
NdAsImage<T, D> for ndarray::ArrayBase<S, D>
|
||||
{
|
||||
/// Borrows the array's buffer and wraps it in an ImageRef without copying
|
||||
fn as_image_ref(&self) -> Result<ImageRef> {
|
||||
let shape = self.shape();
|
||||
let rows = *shape
|
||||
.first()
|
||||
.ok_or_else(|| Report::new(NdFirError).attach_printable("Failed to get rows"))?
|
||||
as u32;
|
||||
let cols = *shape.get(1).unwrap_or(&1) as u32;
|
||||
let channels = *shape.get(2).unwrap_or(&1);
|
||||
let data = self
|
||||
.as_slice()
|
||||
.ok_or(NdFirError)
|
||||
.attach_printable("The ndarray is non continuous")?;
|
||||
let data_bytes: &[u8] = bytemuck::cast_slice(data);
|
||||
|
||||
let pixel_type = to_pixel_type::<T>(channels)?;
|
||||
ImageRef::new(cols, rows, data_bytes, pixel_type)
|
||||
.change_context(NdFirError)
|
||||
.attach_printable("Failed to create Image from ndarray")
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ndarray::DataMut<Elem = T>, T: seal::Sealed + bytemuck::Pod, D: ndarray::Dimension>
|
||||
NdAsImageMut<T, D> for ndarray::ArrayBase<S, D>
|
||||
{
|
||||
fn as_image_ref_mut(&mut self) -> Result<Image<'_>>
|
||||
where
|
||||
S: ndarray::DataMut<Elem = T>,
|
||||
{
|
||||
let shape = self.shape();
|
||||
let rows = *shape
|
||||
.first()
|
||||
.ok_or_else(|| Report::new(NdFirError).attach_printable("Failed to get rows"))?
|
||||
as u32;
|
||||
let cols = *shape.get(1).unwrap_or(&1) as u32;
|
||||
let channels = *shape.get(2).unwrap_or(&1);
|
||||
let data = self
|
||||
.as_slice_mut()
|
||||
.ok_or(NdFirError)
|
||||
.attach_printable("The ndarray is non continuous")?;
|
||||
let data_bytes: &mut [u8] = bytemuck::cast_slice_mut(data);
|
||||
|
||||
let pixel_type = to_pixel_type::<T>(channels)?;
|
||||
Image::from_slice_u8(cols, rows, data_bytes, pixel_type)
|
||||
.change_context(NdFirError)
|
||||
.attach_printable("Failed to create Image from ndarray")
|
||||
}
|
||||
}
|
||||
|
||||
pub trait NdFir<T, D> {
|
||||
fn fast_resize<'o>(
|
||||
&self,
|
||||
height: usize,
|
||||
width: usize,
|
||||
options: impl Into<Option<&'o ResizeOptions>>,
|
||||
) -> Result<ndarray::Array<T, D>>;
|
||||
}
|
||||
|
||||
impl<T: seal::Sealed + bytemuck::Pod + num::Zero, S: ndarray::Data<Elem = T>> NdFir<T, ndarray::Ix3>
|
||||
for ndarray::ArrayBase<S, ndarray::Ix3>
|
||||
{
|
||||
fn fast_resize<'o>(
|
||||
&self,
|
||||
height: usize,
|
||||
width: usize,
|
||||
options: impl Into<Option<&'o ResizeOptions>>,
|
||||
) -> Result<ndarray::Array3<T>> {
|
||||
let source = self.as_image_ref()?;
|
||||
let (_height, _width, channels) = self.dim();
|
||||
let mut dest = ndarray::Array3::<T>::zeros((height, width, channels));
|
||||
let mut dest_image = dest.as_image_ref_mut()?;
|
||||
let mut resizer = fast_image_resize::Resizer::default();
|
||||
resizer
|
||||
.resize(&source, &mut dest_image, options)
|
||||
.change_context(NdFirError)?;
|
||||
Ok(dest)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: seal::Sealed + bytemuck::Pod + num::Zero, S: ndarray::Data<Elem = T>> NdFir<T, ndarray::Ix2>
|
||||
for ndarray::ArrayBase<S, ndarray::Ix2>
|
||||
{
|
||||
fn fast_resize<'o>(
|
||||
&self,
|
||||
height: usize,
|
||||
width: usize,
|
||||
options: impl Into<Option<&'o ResizeOptions>>,
|
||||
) -> Result<ndarray::Array<T, ndarray::Ix2>> {
|
||||
let source = self.as_image_ref()?;
|
||||
let (_height, _width) = self.dim();
|
||||
let mut dest = ndarray::Array::<T, ndarray::Ix2>::zeros((height, width));
|
||||
let mut dest_image = dest.as_image_ref_mut()?;
|
||||
let mut resizer = fast_image_resize::Resizer::default();
|
||||
resizer
|
||||
.resize(&source, &mut dest_image, options)
|
||||
.change_context(NdFirError)?;
|
||||
Ok(dest)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_ndarray_fast_image_resize_u8() {
|
||||
let source_fhd = ndarray::Array3::<u8>::ones((1920, 1080, 3));
|
||||
let mut resized_hd = ndarray::Array3::<u8>::zeros((1280, 720, 3));
|
||||
let mut resizer = fast_image_resize::Resizer::default();
|
||||
resizer
|
||||
.resize(
|
||||
&source_fhd.as_image_ref().unwrap(),
|
||||
&mut resized_hd.as_image_ref_mut().unwrap(),
|
||||
None,
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(resized_hd.shape(), [1280, 720, 3]);
|
||||
}
|
||||
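A hedged sketch of the `NdFir` trait above (paths assume `NdFir` and `NdFirError` are exported): resizing an HxWx3 u8 image with an explicit Lanczos3 filter.

```rust
fn downscale(img: &ndarray::Array3<u8>) -> error_stack::Result<ndarray::Array3<u8>, NdFirError> {
    let opts = fast_image_resize::ResizeOptions::new().resize_alg(
        fast_image_resize::ResizeAlg::Convolution(fast_image_resize::FilterType::Lanczos3),
    );
    // Output height/width are given explicitly; the channel count is taken from `img`.
    img.fast_resize(720, 1280, &opts)
}
```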
307
ndcv-bridge/src/gaussian.rs
Normal file
@@ -0,0 +1,307 @@
|
||||
//! <https://docs.rs/opencv/latest/opencv/imgproc/fn.gaussian_blur.html>
|
||||
use crate::conversions::*;
|
||||
use crate::prelude_::*;
|
||||
use ndarray::*;
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Default, Debug, Copy, Clone)]
|
||||
pub enum BorderType {
|
||||
#[default]
|
||||
BorderConstant = 0,
|
||||
BorderReplicate = 1,
|
||||
BorderReflect = 2,
|
||||
BorderWrap = 3,
|
||||
BorderReflect101 = 4,
|
||||
BorderTransparent = 5,
|
||||
BorderIsolated = 16,
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Default, Debug, Copy, Clone)]
|
||||
pub enum AlgorithmHint {
|
||||
#[default]
|
||||
AlgoHintDefault = 0,
|
||||
AlgoHintAccurate = 1,
|
||||
AlgoHintApprox = 2,
|
||||
}
|
||||
|
||||
mod seal {
|
||||
pub trait Sealed {}
|
||||
// src: input image; the image can have any number of channels, which are processed independently, but the depth should be CV_8U, CV_16U, CV_16S, CV_32F or CV_64F.
|
||||
impl Sealed for u8 {}
|
||||
impl Sealed for u16 {}
|
||||
impl Sealed for i16 {}
|
||||
impl Sealed for f32 {}
|
||||
impl Sealed for f64 {}
|
||||
}
|
||||
|
||||
pub trait NdCvGaussianBlur<T: bytemuck::Pod + seal::Sealed, D: ndarray::Dimension>:
|
||||
crate::image::NdImage + crate::conversions::NdAsImage<T, D>
|
||||
{
|
||||
fn gaussian_blur(
|
||||
&self,
|
||||
kernel_size: (u8, u8),
|
||||
sigma_x: f64,
|
||||
sigma_y: f64,
|
||||
border_type: BorderType,
|
||||
) -> Result<ndarray::Array<T, D>, NdCvError>;
|
||||
fn gaussian_blur_def(
|
||||
&self,
|
||||
kernel: (u8, u8),
|
||||
sigma_x: f64,
|
||||
) -> Result<ndarray::Array<T, D>, NdCvError> {
|
||||
self.gaussian_blur(kernel, sigma_x, sigma_x, BorderType::BorderConstant)
|
||||
}
|
||||
}
|
||||
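A short usage sketch of the trait above (trait imports assumed in scope): kernel sides are passed as u8 and must be odd, per OpenCV's gaussian_blur contract.

```rust
fn blur(img: &ndarray::Array2<f32>) -> error_stack::Result<ndarray::Array2<f32>, NdCvError> {
    // 5x5 kernel, sigma_x = sigma_y = 1.5, constant border (the default used by *_def).
    img.gaussian_blur_def((5, 5), 1.5)
}
```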
|
||||
impl<
|
||||
T: bytemuck::Pod + num::Zero + seal::Sealed,
|
||||
S: ndarray::RawData + ndarray::Data<Elem = T>,
|
||||
D: ndarray::Dimension,
|
||||
> NdCvGaussianBlur<T, D> for ArrayBase<S, D>
|
||||
where
|
||||
ndarray::ArrayBase<S, D>: crate::image::NdImage + crate::conversions::NdAsImage<T, D>,
|
||||
ndarray::Array<T, D>: crate::conversions::NdAsImageMut<T, D>,
|
||||
{
|
||||
fn gaussian_blur(
|
||||
&self,
|
||||
kernel_size: (u8, u8),
|
||||
sigma_x: f64,
|
||||
sigma_y: f64,
|
||||
border_type: BorderType,
|
||||
) -> Result<ndarray::Array<T, D>, NdCvError> {
|
||||
let mut dst = ndarray::Array::zeros(self.dim());
|
||||
let cv_self = self.as_image_mat()?;
|
||||
let mut cv_dst = dst.as_image_mat_mut()?;
|
||||
opencv::imgproc::gaussian_blur(
|
||||
&*cv_self,
|
||||
&mut *cv_dst,
|
||||
opencv::core::Size {
|
||||
width: kernel_size.0 as i32,
|
||||
height: kernel_size.1 as i32,
|
||||
},
|
||||
sigma_x,
|
||||
sigma_y,
|
||||
border_type as i32,
|
||||
)
|
||||
.change_context(NdCvError)
|
||||
.attach_printable("Failed to apply gaussian blur")?;
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
// impl<
|
||||
// T: bytemuck::Pod + num::Zero + seal::Sealed,
|
||||
// S: ndarray::RawData + ndarray::Data<Elem = T>,
|
||||
// > NdCvGaussianBlur<T, Ix3> for ArrayBase<S, Ix3>
|
||||
// {
|
||||
// fn gaussian_blur(
|
||||
// &self,
|
||||
// kernel_size: (u8, u8),
|
||||
// sigma_x: f64,
|
||||
// sigma_y: f64,
|
||||
// border_type: BorderType,
|
||||
// ) -> Result<ndarray::Array<T, Ix3>, NdCvError> {
|
||||
// let mut dst = ndarray::Array::zeros(self.dim());
|
||||
// let cv_self = self.as_image_mat()?;
|
||||
// let mut cv_dst = dst.as_image_mat_mut()?;
|
||||
// opencv::imgproc::gaussian_blur(
|
||||
// &*cv_self,
|
||||
// &mut *cv_dst,
|
||||
// opencv::core::Size {
|
||||
// width: kernel_size.0 as i32,
|
||||
// height: kernel_size.1 as i32,
|
||||
// },
|
||||
// sigma_x,
|
||||
// sigma_y,
|
||||
// border_type as i32,
|
||||
// )
|
||||
// .change_context(NdCvError)
|
||||
// .attach_printable("Failed to apply gaussian blur")?;
|
||||
// Ok(dst)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// impl<
|
||||
// T: bytemuck::Pod + num::Zero + seal::Sealed,
|
||||
// S: ndarray::RawData + ndarray::Data<Elem = T>,
|
||||
// > NdCvGaussianBlur<T, Ix2> for ArrayBase<S, Ix2>
|
||||
// {
|
||||
// fn gaussian_blur(
|
||||
// &self,
|
||||
// kernel_size: (u8, u8),
|
||||
// sigma_x: f64,
|
||||
// sigma_y: f64,
|
||||
// border_type: BorderType,
|
||||
// ) -> Result<ndarray::Array<T, Ix2>, NdCvError> {
|
||||
// let mut dst = ndarray::Array::zeros(self.dim());
|
||||
// let cv_self = self.as_image_mat()?;
|
||||
// let mut cv_dst = dst.as_image_mat_mut()?;
|
||||
// opencv::imgproc::gaussian_blur(
|
||||
// &*cv_self,
|
||||
// &mut *cv_dst,
|
||||
// opencv::core::Size {
|
||||
// width: kernel_size.0 as i32,
|
||||
// height: kernel_size.1 as i32,
|
||||
// },
|
||||
// sigma_x,
|
||||
// sigma_y,
|
||||
// border_type as i32,
|
||||
// )
|
||||
// .change_context(NdCvError)
|
||||
// .attach_printable("Failed to apply gaussian blur")?;
|
||||
// Ok(dst)
|
||||
// }
|
||||
// }
|
||||
|
||||
/// For smaller inputs it is faster to use the allocated version;
/// for a 4K f32 image, this in-place version is about 50% faster than the allocated one.
|
||||
pub trait NdCvGaussianBlurInPlace<T: bytemuck::Pod + seal::Sealed, D: ndarray::Dimension>:
|
||||
crate::image::NdImage + crate::conversions::NdAsImageMut<T, D>
|
||||
{
|
||||
fn gaussian_blur_inplace(
|
||||
&mut self,
|
||||
kernel_size: (u8, u8),
|
||||
sigma_x: f64,
|
||||
sigma_y: f64,
|
||||
border_type: BorderType,
|
||||
) -> Result<&mut Self, NdCvError>;
|
||||
fn gaussian_blur_def_inplace(
|
||||
&mut self,
|
||||
kernel: (u8, u8),
|
||||
sigma_x: f64,
|
||||
) -> Result<&mut Self, NdCvError> {
|
||||
self.gaussian_blur_inplace(kernel, sigma_x, sigma_x, BorderType::BorderConstant)
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
T: bytemuck::Pod + num::Zero + seal::Sealed,
|
||||
S: ndarray::RawData + ndarray::DataMut<Elem = T>,
|
||||
D: ndarray::Dimension,
|
||||
> NdCvGaussianBlurInPlace<T, D> for ArrayBase<S, D>
|
||||
where
|
||||
Self: crate::image::NdImage + crate::conversions::NdAsImageMut<T, D>,
|
||||
{
|
||||
fn gaussian_blur_inplace(
|
||||
&mut self,
|
||||
kernel_size: (u8, u8),
|
||||
sigma_x: f64,
|
||||
sigma_y: f64,
|
||||
border_type: BorderType,
|
||||
) -> Result<&mut Self, NdCvError> {
|
||||
let mut cv_self = self.as_image_mat_mut()?;
|
||||
|
||||
unsafe {
|
||||
crate::inplace::op_inplace(&mut *cv_self, |this, out| {
|
||||
opencv::imgproc::gaussian_blur(
|
||||
this,
|
||||
out,
|
||||
opencv::core::Size {
|
||||
width: kernel_size.0 as i32,
|
||||
height: kernel_size.1 as i32,
|
||||
},
|
||||
sigma_x,
|
||||
sigma_y,
|
||||
border_type as i32,
|
||||
)
|
||||
})
|
||||
}
|
||||
.change_context(NdCvError)
|
||||
.attach_printable("Failed to apply gaussian blur")?;
|
||||
Ok(self)
|
||||
}
|
||||
}
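
// Hedged sketch contrasting the allocated and in-place entry points above
// (illustrative only; mirrors the bounds used by the impls in this file).
#[test]
fn test_gaussian_blur_inplace_sketch() {
    let mut img = ndarray::Array3::<f32>::ones((256, 256, 3));
    // Allocated: returns a new array and leaves the source untouched.
    let blurred = img
        .gaussian_blur((5, 5), 1.5, 1.5, BorderType::BorderReflect101)
        .unwrap();
    assert_eq!(blurred.shape(), img.shape());
    // In-place: overwrites the source buffer, skipping the extra allocation.
    img.gaussian_blur_inplace((5, 5), 1.5, 1.5, BorderType::BorderReflect101)
        .unwrap();
    assert_eq!(img.shape(), &[256, 256, 3]);
}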
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use ndarray::Array3;
|
||||
|
||||
#[test]
|
||||
fn test_gaussian_basic() {
|
||||
let arr = Array3::<u8>::ones((10, 10, 3));
|
||||
let kernel_size = (3, 3);
|
||||
let sigma_x = 0.0;
|
||||
let sigma_y = 0.0;
|
||||
let border_type = BorderType::BorderConstant;
|
||||
let res = arr
|
||||
.gaussian_blur(kernel_size, sigma_x, sigma_y, border_type)
|
||||
.unwrap();
|
||||
assert_eq!(res.shape(), &[10, 10, 3]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gaussian_edge_preservation() {
|
||||
// Create an image with a sharp edge
|
||||
let mut arr = Array3::<u8>::zeros((10, 10, 3));
|
||||
arr.slice_mut(s![..5, .., ..]).fill(255); // Top half white, bottom half black
|
||||
|
||||
let res = arr
|
||||
.gaussian_blur((3, 3), 1.0, 1.0, BorderType::BorderConstant)
|
||||
.unwrap();
|
||||
|
||||
// Check that the middle row (edge) has intermediate values
|
||||
let middle_row = res.slice(s![4..6, 5, 0]);
|
||||
assert!(middle_row.iter().all(|&x| x > 0 && x < 255));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gaussian_different_kernel_sizes() {
|
||||
let arr = Array3::<u8>::ones((20, 20, 3));
|
||||
|
||||
// Test different kernel sizes
|
||||
let kernel_sizes = [(3, 3), (5, 5), (7, 7)];
|
||||
for &kernel_size in &kernel_sizes {
|
||||
let res = arr
|
||||
.gaussian_blur(kernel_size, 1.0, 1.0, BorderType::BorderConstant)
|
||||
.unwrap();
|
||||
assert_eq!(res.shape(), &[20, 20, 3]);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gaussian_different_border_types() {
|
||||
let mut arr = Array3::<u8>::zeros((10, 10, 3));
|
||||
arr.slice_mut(s![4..7, 4..7, ..]).fill(255); // White square in center
|
||||
|
||||
let border_types = [
|
||||
BorderType::BorderConstant,
|
||||
BorderType::BorderReplicate,
|
||||
BorderType::BorderReflect,
|
||||
BorderType::BorderReflect101,
|
||||
];
|
||||
|
||||
for border_type in border_types {
|
||||
let res = arr.gaussian_blur((3, 3), 1.0, 1.0, border_type).unwrap();
|
||||
assert_eq!(res.shape(), &[10, 10, 3]);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gaussian_different_types() {
|
||||
// Test with different numeric types
|
||||
let arr_u8 = Array3::<u8>::ones((10, 10, 3));
|
||||
let arr_f32 = Array3::<f32>::ones((10, 10, 3));
|
||||
|
||||
let res_u8 = arr_u8
|
||||
.gaussian_blur((3, 3), 1.0, 1.0, BorderType::BorderConstant)
|
||||
.unwrap();
|
||||
let res_f32 = arr_f32
|
||||
.gaussian_blur((3, 3), 1.0, 1.0, BorderType::BorderConstant)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(res_u8.shape(), &[10, 10, 3]);
|
||||
assert_eq!(res_f32.shape(), &[10, 10, 3]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_gaussian_invalid_kernel_size() {
|
||||
let arr = Array3::<u8>::ones((10, 10, 3));
|
||||
// Even kernel sizes should fail
|
||||
let _ = arr
|
||||
.gaussian_blur((2, 2), 1.0, 1.0, BorderType::BorderConstant)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
30
ndcv-bridge/src/image.rs
Normal file
@@ -0,0 +1,30 @@
|
||||
use ndarray::*;

pub trait NdImage {
    fn width(&self) -> usize;
    fn height(&self) -> usize;
    fn channels(&self) -> usize;
}

impl<T, S: RawData<Elem = T>> NdImage for ArrayBase<S, Ix3> {
    fn width(&self) -> usize {
        self.dim().1
    }
    fn height(&self) -> usize {
        self.dim().0
    }
    fn channels(&self) -> usize {
        self.dim().2
    }
}

impl<T, S: RawData<Elem = T>> NdImage for ArrayBase<S, Ix2> {
    fn width(&self) -> usize {
        self.dim().1
    }
    fn height(&self) -> usize {
        self.dim().0
    }
    fn channels(&self) -> usize {
        1
    }
}
|
||||
14
ndcv-bridge/src/inplace.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
use opencv::core::Mat;
use opencv::prelude::*;
use opencv::Result;

/// Runs an OpenCV operation "in place" by aliasing the source `Mat` as its own destination.
///
/// # Safety
/// `m_alias` shares the underlying `cv::Mat` pointer with `m`, so the closure must be an
/// OpenCV call that tolerates `src` and `dst` referring to the same buffer.
#[inline(always)]
pub(crate) unsafe fn op_inplace<T>(
    m: &mut Mat,
    f: impl FnOnce(&Mat, &mut Mat) -> Result<T>,
) -> Result<T> {
    // Create a second Mat handle over the same raw pointer as `m`.
    let mut m_alias = Mat::from_raw(m.as_raw_mut());
    let out = f(m, &mut m_alias);
    // Give the alias back as a raw pointer so its Drop does not free the shared data twice.
    let _ = m_alias.into_raw();
    out
}
|
||||
83
ndcv-bridge/src/lib.rs
Normal file
@@ -0,0 +1,83 @@
|
||||
//! Methods and type conversions for ndarray to opencv and vice versa
|
||||
mod blend;
|
||||
// mod dilate;
|
||||
pub mod fir;
|
||||
mod image;
|
||||
mod inplace;
|
||||
pub mod percentile;
|
||||
mod roi;
|
||||
|
||||
#[cfg(feature = "opencv")]
|
||||
pub mod bounding_rect;
|
||||
// #[cfg(feature = "opencv")]
|
||||
// pub mod color_space;
|
||||
#[cfg(feature = "opencv")]
|
||||
pub mod connected_components;
|
||||
#[cfg(feature = "opencv")]
|
||||
pub mod contours;
|
||||
#[cfg(feature = "opencv")]
|
||||
pub mod conversions;
|
||||
// #[cfg(feature = "opencv")]
|
||||
// pub mod gaussian;
|
||||
#[cfg(feature = "opencv")]
|
||||
pub mod resize;
|
||||
|
||||
pub mod codec;
|
||||
pub mod orient;
|
||||
pub use blend::NdBlend;
|
||||
pub use fast_image_resize::{FilterType, ResizeAlg, ResizeOptions, Resizer};
|
||||
pub use fir::NdFir;
|
||||
// pub use gaussian::{BorderType, NdCvGaussianBlur, NdCvGaussianBlurInPlace};
|
||||
pub use roi::{NdRoi, NdRoiMut, NdRoiZeroPadded};
|
||||
|
||||
#[cfg(feature = "opencv")]
|
||||
pub use contours::{
|
||||
ContourApproximationMethod, ContourHierarchy, ContourResult, ContourRetrievalMode,
|
||||
NdCvContourArea, NdCvFindContours,
|
||||
};
|
||||
|
||||
#[cfg(feature = "opencv")]
|
||||
pub use bounding_rect::BoundingRect;
|
||||
#[cfg(feature = "opencv")]
|
||||
pub use connected_components::{Connectivity, NdCvConnectedComponents};
|
||||
#[cfg(feature = "opencv")]
|
||||
pub use conversions::{MatAsNd, NdAsImage, NdAsImageMut, NdAsMat, NdAsMatMut, NdCvConversion};
|
||||
#[cfg(feature = "opencv")]
|
||||
pub use resize::{Interpolation, NdCvResize};
|
||||
|
||||
pub(crate) mod prelude_ {
|
||||
pub use crate::NdCvError;
|
||||
pub use error_stack::*;
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("NdCvError")]
|
||||
pub struct NdCvError;
|
||||
|
||||
#[cfg(feature = "opencv")]
|
||||
pub fn type_depth<T>() -> i32 {
|
||||
match std::any::type_name::<T>() {
|
||||
"u8" => opencv::core::CV_8U,
|
||||
"i8" => opencv::core::CV_8S,
|
||||
"u16" => opencv::core::CV_16U,
|
||||
"i16" => opencv::core::CV_16S,
|
||||
"i32" => opencv::core::CV_32S,
|
||||
"f32" => opencv::core::CV_32F,
|
||||
"f64" => opencv::core::CV_64F,
|
||||
_ => panic!("Unsupported type"),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "opencv")]
|
||||
pub fn depth_type(depth: i32) -> &'static str {
|
||||
match depth {
|
||||
opencv::core::CV_8U => "u8",
|
||||
opencv::core::CV_8S => "i8",
|
||||
opencv::core::CV_16U => "u16",
|
||||
opencv::core::CV_16S => "i16",
|
||||
opencv::core::CV_32S => "i32",
|
||||
opencv::core::CV_32F => "f32",
|
||||
opencv::core::CV_64F => "f64",
|
||||
_ => panic!("Unsupported depth"),
|
||||
}
|
||||
}
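
// Hedged sketch: round-tripping a Rust element type through the OpenCV depth helpers above.
#[cfg(feature = "opencv")]
#[test]
fn test_type_depth_roundtrip_sketch() {
    let depth = type_depth::<f32>();
    assert_eq!(depth, opencv::core::CV_32F);
    assert_eq!(depth_type(depth), "f32");
}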
|
||||
188
ndcv-bridge/src/orient.rs
Normal file
@@ -0,0 +1,188 @@
|
||||
use ndarray::{Array, ArrayBase, ArrayView};
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum Orientation {
|
||||
NoRotation,
|
||||
Mirror,
|
||||
Clock180,
|
||||
Water,
|
||||
MirrorClock270,
|
||||
Clock90,
|
||||
MirrorClock90,
|
||||
Clock270,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl Orientation {
|
||||
pub fn inverse(&self) -> Self {
|
||||
match self {
|
||||
Self::Clock90 => Self::Clock270,
|
||||
Self::Clock270 => Self::Clock90,
|
||||
_ => *self,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Orientation {
|
||||
pub fn from_raw(flip: u8) -> Self {
|
||||
match flip {
|
||||
1 => Orientation::NoRotation,
|
||||
2 => Orientation::Mirror,
|
||||
3 => Orientation::Clock180,
|
||||
4 => Orientation::Water,
|
||||
5 => Orientation::MirrorClock270,
|
||||
6 => Orientation::Clock90,
|
||||
7 => Orientation::MirrorClock90,
|
||||
8 => Orientation::Clock270,
|
||||
_ => Orientation::Unknown,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum RotationFlag {
|
||||
Clock90,
|
||||
Clock180,
|
||||
Clock270,
|
||||
}
|
||||
|
||||
impl RotationFlag {
|
||||
pub fn neg(&self) -> Self {
|
||||
match self {
|
||||
RotationFlag::Clock90 => RotationFlag::Clock270,
|
||||
RotationFlag::Clock180 => RotationFlag::Clock180,
|
||||
RotationFlag::Clock270 => RotationFlag::Clock90,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_orientation(&self) -> Orientation {
|
||||
match self {
|
||||
RotationFlag::Clock90 => Orientation::Clock90,
|
||||
RotationFlag::Clock180 => Orientation::Clock180,
|
||||
RotationFlag::Clock270 => Orientation::Clock270,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum FlipFlag {
|
||||
Mirror,
|
||||
Water,
|
||||
Both,
|
||||
}
|
||||
|
||||
pub trait Orient<T: bytemuck::Pod, D: ndarray::Dimension> {
|
||||
fn flip(&self, flip: FlipFlag) -> Array<T, D>;
|
||||
fn rotate(&self, rotation: RotationFlag) -> Array<T, D>;
|
||||
fn owned(&self) -> Array<T, D>;
|
||||
|
||||
fn unorient(&self, orientation: Orientation) -> Array<T, D>
|
||||
where
|
||||
Array<T, D>: Orient<T, D>,
|
||||
Self: ToOwned<Owned = Array<T, D>>,
|
||||
{
|
||||
let inverse_orientation = orientation.inverse();
|
||||
self.orient(inverse_orientation)
|
||||
|
||||
// match orientation {
|
||||
// Orientation::NoRotation | Orientation::Unknown => self.to_owned(),
|
||||
// Orientation::Mirror => self.flip(FlipFlag::Mirror).to_owned(),
|
||||
// Orientation::Clock180 => self.rotate(RotationFlag::Clock180),
|
||||
// Orientation::Water => self.flip(FlipFlag::Water).to_owned(),
|
||||
// Orientation::MirrorClock270 => self
|
||||
// .rotate(RotationFlag::Clock90)
|
||||
// .flip(FlipFlag::Mirror)
|
||||
// .to_owned(),
|
||||
// Orientation::Clock90 => self.rotate(RotationFlag::Clock270),
|
||||
// Orientation::MirrorClock90 => self
|
||||
// .rotate(RotationFlag::Clock270)
|
||||
// .flip(FlipFlag::Mirror)
|
||||
// .to_owned(),
|
||||
// Orientation::Clock270 => self.rotate(RotationFlag::Clock90),
|
||||
// }
|
||||
}
|
||||
|
||||
fn orient(&self, orientation: Orientation) -> Array<T, D>
|
||||
where
|
||||
Array<T, D>: Orient<T, D>,
|
||||
{
|
||||
match orientation {
|
||||
Orientation::NoRotation | Orientation::Unknown => self.owned(),
|
||||
Orientation::Mirror => self.flip(FlipFlag::Mirror).to_owned(),
|
||||
Orientation::Clock180 => self.rotate(RotationFlag::Clock180),
|
||||
Orientation::Water => self.flip(FlipFlag::Water).to_owned(),
|
||||
Orientation::MirrorClock270 => self
|
||||
.flip(FlipFlag::Mirror)
|
||||
.rotate(RotationFlag::Clock270)
|
||||
.to_owned(),
|
||||
Orientation::Clock90 => self.rotate(RotationFlag::Clock90),
|
||||
Orientation::MirrorClock90 => self
|
||||
.flip(FlipFlag::Mirror)
|
||||
.rotate(RotationFlag::Clock90)
|
||||
.to_owned(),
|
||||
Orientation::Clock270 => self.rotate(RotationFlag::Clock270),
|
||||
}
|
||||
.as_standard_layout()
|
||||
.to_owned()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: bytemuck::Pod + Copy, S: ndarray::Data<Elem = T>> Orient<T, ndarray::Ix3>
|
||||
for ArrayBase<S, ndarray::Ix3>
|
||||
{
|
||||
fn flip(&self, flip: FlipFlag) -> Array<T, ndarray::Ix3> {
|
||||
match flip {
|
||||
FlipFlag::Mirror => self.slice(ndarray::s![.., ..;-1, ..]),
|
||||
FlipFlag::Water => self.slice(ndarray::s![..;-1, .., ..]),
|
||||
FlipFlag::Both => self.slice(ndarray::s![..;-1, ..;-1, ..]),
|
||||
}
|
||||
.as_standard_layout()
|
||||
.to_owned()
|
||||
}
|
||||
|
||||
fn owned(&self) -> Array<T, ndarray::Ix3> {
|
||||
self.to_owned()
|
||||
}
|
||||
|
||||
fn rotate(&self, rotation: RotationFlag) -> Array<T, ndarray::Ix3> {
|
||||
match rotation {
|
||||
RotationFlag::Clock90 => self
|
||||
.view()
|
||||
.permuted_axes([1, 0, 2])
|
||||
.flip(FlipFlag::Mirror)
|
||||
.to_owned(),
|
||||
RotationFlag::Clock180 => self.flip(FlipFlag::Both).to_owned(),
|
||||
RotationFlag::Clock270 => self
|
||||
.view()
|
||||
.permuted_axes([1, 0, 2])
|
||||
.flip(FlipFlag::Water)
|
||||
.to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: bytemuck::Pod + Copy, S: ndarray::Data<Elem = T>> Orient<T, ndarray::Ix2>
|
||||
for ArrayBase<S, ndarray::Ix2>
|
||||
{
|
||||
fn flip(&self, flip: FlipFlag) -> Array<T, ndarray::Ix2> {
|
||||
match flip {
|
||||
FlipFlag::Mirror => self.slice(ndarray::s![.., ..;-1,]),
|
||||
FlipFlag::Water => self.slice(ndarray::s![..;-1, ..,]),
|
||||
FlipFlag::Both => self.slice(ndarray::s![..;-1, ..;-1,]),
|
||||
}
|
||||
.as_standard_layout()
|
||||
.to_owned()
|
||||
}
|
||||
|
||||
fn owned(&self) -> Array<T, ndarray::Ix2> {
|
||||
self.to_owned()
|
||||
}
|
||||
|
||||
fn rotate(&self, rotation: RotationFlag) -> Array<T, ndarray::Ix2> {
|
||||
match rotation {
|
||||
RotationFlag::Clock90 => self.t().flip(FlipFlag::Mirror).to_owned(),
|
||||
RotationFlag::Clock180 => self.flip(FlipFlag::Both).to_owned(),
|
||||
RotationFlag::Clock270 => self.t().flip(FlipFlag::Water).to_owned(),
|
||||
}
|
||||
}
|
||||
}
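
// Hedged usage sketch of the Orient trait above: Orientation::from_raw maps EXIF tag 6
// to Clock90, so orienting a 1080x1920x3 HWC buffer yields a 1920x1080x3 result.
#[test]
fn test_orient_exif_sketch() {
    let img = ndarray::Array3::<u8>::zeros((1080, 1920, 3));
    let upright = img.orient(Orientation::from_raw(6));
    assert_eq!(upright.shape(), &[1920, 1080, 3]);
}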
|
||||
63
ndcv-bridge/src/percentile.rs
Normal file
@@ -0,0 +1,63 @@
|
||||
use error_stack::*;
|
||||
use ndarray::{ArrayBase, Ix1};
|
||||
use num::cast::AsPrimitive;
|
||||
|
||||
use crate::NdCvError;
|
||||
|
||||
pub trait Percentile {
|
||||
fn percentile(&self, qth_percentile: f64) -> Result<f64, NdCvError>;
|
||||
}
|
||||
|
||||
impl<T: std::cmp::Ord + Clone + AsPrimitive<f64>, S: ndarray::Data<Elem = T>> Percentile
|
||||
for ArrayBase<S, Ix1>
|
||||
{
|
||||
fn percentile(&self, qth_percentile: f64) -> Result<f64, NdCvError> {
|
||||
if self.len() == 0 {
|
||||
return Err(error_stack::Report::new(NdCvError).attach_printable("Empty Input"));
|
||||
}
|
||||
|
||||
if !(0_f64..=1_f64).contains(&qth_percentile) {
return Err(error_stack::Report::new(NdCvError)
.attach_printable("Qth percentile must be between 0 and 1 (inclusive)"));
}
|
||||
|
||||
let mut standard_array = self.as_standard_layout();
|
||||
let mut raw_data = standard_array
|
||||
.as_slice_mut()
|
||||
.expect("An array in standard layout will always return its inner slice");
|
||||
|
||||
raw_data.sort();
|
||||
|
||||
let actual_index = qth_percentile * (raw_data.len() - 1) as f64;
|
||||
|
||||
let lower_index = (actual_index.floor() as usize).clamp(0, raw_data.len() - 1);
|
||||
let upper_index = (actual_index.ceil() as usize).clamp(0, raw_data.len() - 1);
|
||||
|
||||
if lower_index == upper_index {
|
||||
Ok(raw_data[lower_index].as_())
|
||||
} else {
|
||||
let weight = actual_index - lower_index as f64;
|
||||
Ok(raw_data[lower_index].as_() * (1.0 - weight) + raw_data[upper_index].as_() * weight)
|
||||
}
|
||||
}
|
||||
}
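
// Hedged usage sketch of the Percentile trait above on integer data.
#[test]
fn test_percentile_median_sketch() {
    let values = ndarray::Array1::from(vec![1_u32, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
    // actual_index = 0.5 * 9 = 4.5, interpolated halfway between 5 and 6.
    let median = values.percentile(0.5).unwrap();
    assert!((median - 5.5).abs() < 1e-9);
}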
|
||||
|
||||
// fn percentile(data: &Array1<f64>, p: f64) -> f64 {
|
||||
// if data.len() == 0 {
|
||||
// return 0.0;
|
||||
// }
|
||||
//
|
||||
// let mut sorted_data = data.to_vec();
|
||||
// sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
//
|
||||
// let index = (p / 100.0) * (sorted_data.len() - 1) as f64;
|
||||
// let lower = index.floor() as usize;
|
||||
// let upper = index.ceil() as usize;
|
||||
//
|
||||
// if lower == upper {
|
||||
// sorted_data[lower] as f64
|
||||
// } else {
|
||||
// let weight = index - lower as f64;
|
||||
// sorted_data[lower] as f64 * (1.0 - weight) + sorted_data[upper] as f64 * weight
|
||||
// }
|
||||
// }
|
||||
108
ndcv-bridge/src/resize.rs
Normal file
@@ -0,0 +1,108 @@
|
||||
use crate::{prelude_::*, NdAsImage, NdAsImageMut};
|
||||
|
||||
/// Resize ndarray using OpenCV resize functions
|
||||
pub trait NdCvResize<T, D>: seal::SealedInternal {
|
||||
/// The input array must be a continuous 2D or 3D ndarray
|
||||
fn resize(
|
||||
&self,
|
||||
height: u16,
|
||||
width: u16,
|
||||
interpolation: Interpolation,
|
||||
) -> Result<ndarray::ArrayBase<ndarray::OwnedRepr<T>, D>, NdCvError>;
|
||||
}
|
||||
|
||||
#[repr(i32)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub enum Interpolation {
|
||||
Linear = opencv::imgproc::INTER_LINEAR,
|
||||
LinearExact = opencv::imgproc::INTER_LINEAR_EXACT,
|
||||
Max = opencv::imgproc::INTER_MAX,
|
||||
Area = opencv::imgproc::INTER_AREA,
|
||||
Cubic = opencv::imgproc::INTER_CUBIC,
|
||||
Nearest = opencv::imgproc::INTER_NEAREST,
|
||||
NearestExact = opencv::imgproc::INTER_NEAREST_EXACT,
|
||||
Lanczos4 = opencv::imgproc::INTER_LANCZOS4,
|
||||
}
|
||||
|
||||
mod seal {
|
||||
pub trait SealedInternal {}
|
||||
impl<T: bytemuck::Pod, S: ndarray::Data<Elem = T>> SealedInternal
|
||||
for ndarray::ArrayBase<S, ndarray::Ix3>
|
||||
{
|
||||
}
|
||||
impl<T: bytemuck::Pod, S: ndarray::Data<Elem = T>> SealedInternal
|
||||
for ndarray::ArrayBase<S, ndarray::Ix2>
|
||||
{
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: bytemuck::Pod + num::Zero, S: ndarray::Data<Elem = T>> NdCvResize<T, ndarray::Ix2>
|
||||
for ndarray::ArrayBase<S, ndarray::Ix2>
|
||||
{
|
||||
fn resize(
|
||||
&self,
|
||||
height: u16,
|
||||
width: u16,
|
||||
interpolation: Interpolation,
|
||||
) -> Result<ndarray::Array2<T>, NdCvError> {
|
||||
let mat = self.as_image_mat()?;
|
||||
let mut dest = ndarray::Array2::zeros((height.into(), width.into()));
|
||||
let mut dest_mat = dest.as_image_mat_mut()?;
|
||||
opencv::imgproc::resize(
|
||||
mat.as_ref(),
|
||||
dest_mat.as_mut(),
|
||||
opencv::core::Size {
|
||||
height: height.into(),
|
||||
width: width.into(),
|
||||
},
|
||||
0.,
|
||||
0.,
|
||||
interpolation as i32,
|
||||
)
|
||||
.change_context(NdCvError)?;
|
||||
Ok(dest)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: bytemuck::Pod + num::Zero, S: ndarray::Data<Elem = T>> NdCvResize<T, ndarray::Ix3>
|
||||
for ndarray::ArrayBase<S, ndarray::Ix3>
|
||||
{
|
||||
fn resize(
|
||||
&self,
|
||||
height: u16,
|
||||
width: u16,
|
||||
interpolation: Interpolation,
|
||||
) -> Result<ndarray::ArrayBase<ndarray::OwnedRepr<T>, ndarray::Ix3>, NdCvError> {
|
||||
let mat = self.as_image_mat()?;
|
||||
let mut dest =
|
||||
ndarray::Array3::zeros((height.into(), width.into(), self.len_of(ndarray::Axis(2))));
|
||||
let mut dest_mat = dest.as_image_mat_mut()?;
|
||||
opencv::imgproc::resize(
|
||||
mat.as_ref(),
|
||||
dest_mat.as_mut(),
|
||||
opencv::core::Size {
|
||||
height: height.into(),
|
||||
width: width.into(),
|
||||
},
|
||||
0.,
|
||||
0.,
|
||||
interpolation as i32,
|
||||
)
|
||||
.change_context(NdCvError)?;
|
||||
Ok(dest)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resize_simple() {
|
||||
let foo = ndarray::Array2::<u8>::ones((10, 10));
|
||||
let foo_resized = foo.resize(15, 20, Interpolation::Linear).unwrap();
|
||||
assert_eq!(foo_resized, ndarray::Array2::<u8>::ones((15, 20)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resize_3d() {
|
||||
let foo = ndarray::Array3::<u8>::ones((10, 10, 3));
|
||||
let foo_resized = foo.resize(15, 20, Interpolation::Linear).unwrap();
|
||||
assert_eq!(foo_resized, ndarray::Array3::<u8>::ones((15, 20, 3)));
|
||||
}
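
// Hedged sketch mirroring the tests above: Interpolation::Area is the usual choice when
// downscaling photographic content, since it averages source pixels.
#[test]
fn test_resize_downscale_area_sketch() {
    let src = ndarray::Array3::<u8>::from_elem((100, 100, 3), 200);
    let dst = src.resize(25, 25, Interpolation::Area).unwrap();
    assert_eq!(dst.shape(), &[25, 25, 3]);
}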
|
||||
274
ndcv-bridge/src/roi.rs
Normal file
@@ -0,0 +1,274 @@
|
||||
pub trait NdRoi<T, D>: seal::Sealed {
|
||||
fn roi(&self, rect: bounding_box::Aabb2<usize>) -> ndarray::ArrayView<T, D>;
|
||||
}
|
||||
|
||||
pub trait NdRoiMut<T, D>: seal::Sealed {
|
||||
fn roi_mut(&mut self, rect: bounding_box::Aabb2<usize>) -> ndarray::ArrayViewMut<T, D>;
|
||||
}
|
||||
|
||||
mod seal {
|
||||
use ndarray::{Ix2, Ix3};
|
||||
pub trait Sealed {}
|
||||
impl<T: bytemuck::Pod, S: ndarray::Data<Elem = T>> Sealed for ndarray::ArrayBase<S, Ix2> {}
|
||||
impl<T: bytemuck::Pod, S: ndarray::Data<Elem = T>> Sealed for ndarray::ArrayBase<S, Ix3> {}
|
||||
}
|
||||
|
||||
impl<T: bytemuck::Pod, S: ndarray::Data<Elem = T>> NdRoi<T, ndarray::Ix3>
|
||||
for ndarray::ArrayBase<S, ndarray::Ix3>
|
||||
{
|
||||
fn roi(&self, rect: bounding_box::Aabb2<usize>) -> ndarray::ArrayView3<T> {
|
||||
let y1 = rect.y1();
|
||||
let y2 = rect.y2();
|
||||
let x1 = rect.x1();
|
||||
let x2 = rect.x2();
|
||||
self.slice(ndarray::s![y1..y2, x1..x2, ..])
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: bytemuck::Pod, S: ndarray::DataMut<Elem = T>> NdRoiMut<T, ndarray::Ix3>
|
||||
for ndarray::ArrayBase<S, ndarray::Ix3>
|
||||
{
|
||||
fn roi_mut(&mut self, rect: bounding_box::Aabb2<usize>) -> ndarray::ArrayViewMut3<T> {
|
||||
let y1 = rect.y1();
|
||||
let y2 = rect.y2();
|
||||
let x1 = rect.x1();
|
||||
let x2 = rect.x2();
|
||||
self.slice_mut(ndarray::s![y1..y2, x1..x2, ..])
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: bytemuck::Pod, S: ndarray::Data<Elem = T>> NdRoi<T, ndarray::Ix2>
|
||||
for ndarray::ArrayBase<S, ndarray::Ix2>
|
||||
{
|
||||
fn roi(&self, rect: bounding_box::Aabb2<usize>) -> ndarray::ArrayView2<T> {
|
||||
let y1 = rect.y1();
|
||||
let y2 = rect.y2();
|
||||
let x1 = rect.x1();
|
||||
let x2 = rect.x2();
|
||||
self.slice(ndarray::s![y1..y2, x1..x2])
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: bytemuck::Pod, S: ndarray::DataMut<Elem = T>> NdRoiMut<T, ndarray::Ix2>
|
||||
for ndarray::ArrayBase<S, ndarray::Ix2>
|
||||
{
|
||||
fn roi_mut(&mut self, rect: bounding_box::Aabb2<usize>) -> ndarray::ArrayViewMut2<T> {
|
||||
let y1 = rect.y1();
|
||||
let y2 = rect.y2();
|
||||
let x1 = rect.x1();
|
||||
let x2 = rect.x2();
|
||||
self.slice_mut(ndarray::s![y1..y2, x1..x2])
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_roi() {
|
||||
let arr = ndarray::Array3::<u8>::zeros((10, 10, 3));
|
||||
let rect = bounding_box::Aabb2::from_xywh(1, 1, 3, 3);
|
||||
let roi = arr.roi(rect);
|
||||
assert_eq!(roi.shape(), &[3, 3, 3]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_roi_2d() {
|
||||
let arr = ndarray::Array2::<u8>::zeros((10, 10));
|
||||
let rect = bounding_box::Aabb2::from_xywh(1, 1, 3, 3);
|
||||
let roi = arr.roi(rect);
|
||||
assert_eq!(roi.shape(), &[3, 3]);
|
||||
}
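
// Hedged sketch of the mutable counterpart: writing into a sub-region through roi_mut.
#[test]
fn test_roi_mut_sketch() {
    let mut arr = ndarray::Array2::<u8>::zeros((10, 10));
    let rect = bounding_box::Aabb2::from_xywh(2, 2, 4, 4);
    arr.roi_mut(rect).fill(255);
    assert_eq!(arr[(3, 3)], 255);
    assert_eq!(arr[(0, 0)], 0);
}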
|
||||
|
||||
/// ```text
|
||||
/// ┌──────────────────┐
|
||||
/// │ padded │
|
||||
/// │ ┌────────┐ │
|
||||
/// │ │ │ │
|
||||
/// │ │original│ │
|
||||
/// │ │ │ │
|
||||
/// │ └────────┘ │
|
||||
/// │ zeroed │
|
||||
/// └──────────────────┘
|
||||
/// ```
|
||||
///
|
||||
/// Returns the padded bounding box and the padded segment.
/// `padded` is the padded bounding box and `original` is the original bounding box;
/// the padded segment is zero-filled, with the original segment copied into it at its offset within `padded`.
|
||||
// Helper functions for missing methods from old bbox crate
|
||||
fn bbox_top_left_usize(bbox: &bounding_box::Aabb2<usize>) -> (usize, usize) {
|
||||
(bbox.x1(), bbox.y1())
|
||||
}
|
||||
|
||||
fn bbox_with_top_left_usize(
|
||||
bbox: &bounding_box::Aabb2<usize>,
|
||||
x: usize,
|
||||
y: usize,
|
||||
) -> bounding_box::Aabb2<usize> {
|
||||
let width = bbox.x2() - bbox.x1();
|
||||
let height = bbox.y2() - bbox.y1();
|
||||
bounding_box::Aabb2::from_xywh(x, y, width, height)
|
||||
}
|
||||
|
||||
fn bbox_with_origin_usize(point: (usize, usize), origin: (usize, usize)) -> (usize, usize) {
|
||||
(point.0 - origin.0, point.1 - origin.1)
|
||||
}
|
||||
|
||||
fn bbox_zeros_ndarray_2d<T: num::Zero + Copy>(
|
||||
bbox: &bounding_box::Aabb2<usize>,
|
||||
) -> ndarray::Array2<T> {
|
||||
let width = bbox.x2() - bbox.x1();
|
||||
let height = bbox.y2() - bbox.y1();
|
||||
ndarray::Array2::<T>::zeros((height, width))
|
||||
}
|
||||
|
||||
fn bbox_zeros_ndarray_3d<T: num::Zero + Copy>(
|
||||
bbox: &bounding_box::Aabb2<usize>,
|
||||
channels: usize,
|
||||
) -> ndarray::Array3<T> {
|
||||
let width = bbox.x2() - bbox.x1();
|
||||
let height = bbox.y2() - bbox.y1();
|
||||
ndarray::Array3::<T>::zeros((height, width, channels))
|
||||
}
|
||||
|
||||
fn bbox_round_f64(bbox: &bounding_box::Aabb2<f64>) -> bounding_box::Aabb2<f64> {
|
||||
let x1 = bbox.x1().round();
|
||||
let y1 = bbox.y1().round();
|
||||
let x2 = bbox.x2().round();
|
||||
let y2 = bbox.y2().round();
|
||||
bounding_box::Aabb2::from_x1y1x2y2(x1, y1, x2, y2)
|
||||
}
|
||||
|
||||
fn bbox_cast_f64_to_usize(bbox: &bounding_box::Aabb2<f64>) -> bounding_box::Aabb2<usize> {
|
||||
let x1 = bbox.x1() as usize;
|
||||
let y1 = bbox.y1() as usize;
|
||||
let x2 = bbox.x2() as usize;
|
||||
let y2 = bbox.y2() as usize;
|
||||
bounding_box::Aabb2::from_x1y1x2y2(x1, y1, x2, y2)
|
||||
}
|
||||
|
||||
pub trait NdRoiZeroPadded<T, D: ndarray::Dimension>: seal::Sealed {
|
||||
fn roi_zero_padded(
|
||||
&self,
|
||||
original: bounding_box::Aabb2<usize>,
|
||||
padded: bounding_box::Aabb2<usize>,
|
||||
) -> (bounding_box::Aabb2<usize>, ndarray::Array<T, D>);
|
||||
}
|
||||
|
||||
impl<T: bytemuck::Pod + num::Zero> NdRoiZeroPadded<T, ndarray::Ix2> for ndarray::Array2<T> {
|
||||
fn roi_zero_padded(
|
||||
&self,
|
||||
original: bounding_box::Aabb2<usize>,
|
||||
padded: bounding_box::Aabb2<usize>,
|
||||
) -> (bounding_box::Aabb2<usize>, ndarray::Array2<T>) {
|
||||
// The co-ordinates of both the original and the padded bounding boxes must be contained in
|
||||
// self or it will panic
|
||||
|
||||
let self_bbox = bounding_box::Aabb2::from_xywh(0, 0, self.shape()[1], self.shape()[0]);
|
||||
if !self_bbox.contains_bbox(&original) {
|
||||
panic!("original bounding box is not contained in self");
|
||||
}
|
||||
if !self_bbox.contains_bbox(&padded) {
|
||||
panic!("padded bounding box is not contained in self");
|
||||
}
|
||||
|
||||
let original_top_left = bbox_top_left_usize(&original);
|
||||
let padded_top_left = bbox_top_left_usize(&padded);
|
||||
let origin_offset = bbox_with_origin_usize(original_top_left, padded_top_left);
|
||||
let original_roi_in_padded =
|
||||
bbox_with_top_left_usize(&original, origin_offset.0, origin_offset.1);
|
||||
|
||||
let original_segment = self.roi(original);
|
||||
let mut padded_segment = bbox_zeros_ndarray_2d::<T>(&padded);
|
||||
padded_segment
|
||||
.roi_mut(original_roi_in_padded)
|
||||
.assign(&original_segment);
|
||||
|
||||
(padded, padded_segment)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: bytemuck::Pod + num::Zero> NdRoiZeroPadded<T, ndarray::Ix3> for ndarray::Array3<T> {
|
||||
fn roi_zero_padded(
|
||||
&self,
|
||||
original: bounding_box::Aabb2<usize>,
|
||||
padded: bounding_box::Aabb2<usize>,
|
||||
) -> (bounding_box::Aabb2<usize>, ndarray::Array3<T>) {
|
||||
let self_bbox = bounding_box::Aabb2::from_xywh(0, 0, self.shape()[1], self.shape()[0]);
|
||||
if !self_bbox.contains_bbox(&original) {
|
||||
panic!("original bounding box is not contained in self");
|
||||
}
|
||||
if !self_bbox.contains_bbox(&padded) {
|
||||
panic!("padded bounding box is not contained in self");
|
||||
}
|
||||
|
||||
let original_top_left = bbox_top_left_usize(&original);
|
||||
let padded_top_left = bbox_top_left_usize(&padded);
|
||||
let origin_offset = bbox_with_origin_usize(original_top_left, padded_top_left);
|
||||
let original_roi_in_padded =
|
||||
bbox_with_top_left_usize(&original, origin_offset.0, origin_offset.1);
|
||||
|
||||
let original_segment = self.roi(original);
|
||||
let mut padded_segment = bbox_zeros_ndarray_3d::<T>(&padded, self.len_of(ndarray::Axis(2)));
|
||||
padded_segment
|
||||
.roi_mut(original_roi_in_padded)
|
||||
.assign(&original_segment);
|
||||
|
||||
(padded, padded_segment)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_roi_zero_padded() {
|
||||
let arr = ndarray::Array2::<u8>::ones((10, 10));
|
||||
let original = bounding_box::Aabb2::from_xywh(1.0, 1.0, 3.0, 3.0);
|
||||
let clamp = bounding_box::Aabb2::from_xywh(0.0, 0.0, 10.0, 10.0);
|
||||
let padded = original.padding(2.0).clamp(&clamp).unwrap();
|
||||
let padded_cast = bbox_cast_f64_to_usize(&padded);
|
||||
let original_cast = bbox_cast_f64_to_usize(&original);
|
||||
let (padded_result, padded_segment) = arr.roi_zero_padded(original_cast, padded_cast);
|
||||
assert_eq!(padded_result, bounding_box::Aabb2::from_xywh(0, 0, 6, 6));
|
||||
assert_eq!(padded_segment.shape(), &[6, 6]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn bbox_clamp_failure_preload() {
|
||||
let segment_mask = ndarray::Array2::<u8>::zeros((512, 512));
|
||||
let og = bounding_box::Aabb2::from_xywh(475.0, 79.625, 37.0, 282.15);
|
||||
let clamp = bounding_box::Aabb2::from_xywh(0.0, 0.0, 512.0, 512.0);
|
||||
let padded = og
|
||||
.scale(nalgebra::Vector2::new(1.2, 1.2))
|
||||
.clamp(&clamp)
|
||||
.unwrap();
|
||||
let padded = bbox_round_f64(&padded);
|
||||
let og_cast = bbox_cast_f64_to_usize(&bbox_round_f64(&og));
|
||||
let padded_cast = bbox_cast_f64_to_usize(&padded);
|
||||
let (_bbox, _segment) = segment_mask.roi_zero_padded(og_cast, padded_cast);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn bbox_clamp_failure_preload_2() {
|
||||
let segment_mask = ndarray::Array2::<u8>::zeros((512, 512));
|
||||
let bbox = bounding_box::Aabb2::from_xywh(354.25, 98.5, 116.25, 413.5);
|
||||
// let padded = bounding_box::Aabb2::from_xywh(342.625, 57.150000000000006, 139.5, 454.85);
|
||||
let clamp = bounding_box::Aabb2::from_xywh(0.0, 0.0, 512.0, 512.0);
|
||||
let padded = bbox
|
||||
.scale(nalgebra::Vector2::new(1.2, 1.2))
|
||||
.clamp(&clamp)
|
||||
.unwrap();
|
||||
let padded = bbox_round_f64(&padded);
|
||||
let bbox_cast = bbox_cast_f64_to_usize(&bbox_round_f64(&bbox));
|
||||
let padded_cast = bbox_cast_f64_to_usize(&padded);
|
||||
let (_bbox, _segment) = segment_mask.roi_zero_padded(bbox_cast, padded_cast);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_roi_zero_padded_3d() {
|
||||
let arr = ndarray::Array3::<u8>::ones((10, 10, 3));
|
||||
let original = bounding_box::Aabb2::from_xywh(1.0, 1.0, 3.0, 3.0);
|
||||
let clamp = bounding_box::Aabb2::from_xywh(0.0, 0.0, 10.0, 10.0);
|
||||
let padded = original.padding(2.0).clamp(&clamp).unwrap();
|
||||
let padded_cast = bbox_cast_f64_to_usize(&padded);
|
||||
let original_cast = bbox_cast_f64_to_usize(&original);
|
||||
let (padded_result, padded_segment) = arr.roi_zero_padded(original_cast, padded_cast);
|
||||
assert_eq!(padded_result, bounding_box::Aabb2::from_xywh(0, 0, 6, 6));
|
||||
assert_eq!(padded_segment.shape(), &[6, 6, 3]);
|
||||
}
|
||||
42
patches/ort_env_global_mutex.patch
Normal file
@@ -0,0 +1,42 @@
|
||||
From 83e1dbf52b7695a2966795e0350aaa385d1ba8c8 Mon Sep 17 00:00:00 2001
|
||||
From: "Carson M." <carson@pyke.io>
|
||||
Date: Sun, 22 Jun 2025 23:53:20 -0500
|
||||
Subject: [PATCH] Leak logger mutex
|
||||
|
||||
---
|
||||
onnxruntime/core/common/logging/logging.cc | 8 ++++----
|
||||
1 file changed, 4 insertions(+), 4 deletions(-)
|
||||
|
||||
diff --git a/onnxruntime/core/common/logging/logging.cc b/onnxruntime/core/common/logging/logging.cc
|
||||
index a79e7300cffce..07578fc72ca99 100644
|
||||
--- a/onnxruntime/core/common/logging/logging.cc
|
||||
+++ b/onnxruntime/core/common/logging/logging.cc
|
||||
@@ -64,8 +64,8 @@ LoggingManager* LoggingManager::GetDefaultInstance() {
|
||||
#pragma warning(disable : 26426)
|
||||
#endif
|
||||
|
||||
-static std::mutex& DefaultLoggerMutex() noexcept {
|
||||
- static std::mutex mutex;
|
||||
+static std::mutex* DefaultLoggerMutex() noexcept {
|
||||
+ static std::mutex* mutex = new std::mutex();
|
||||
return mutex;
|
||||
}
|
||||
|
||||
@@ -107,7 +107,7 @@ LoggingManager::LoggingManager(std::unique_ptr<ISink> sink, Severity default_min
|
||||
|
||||
// lock mutex to create instance, and enable logging
|
||||
// this matches the mutex usage in Shutdown
|
||||
- std::lock_guard<std::mutex> guard(DefaultLoggerMutex());
|
||||
+ std::lock_guard<std::mutex> guard(*DefaultLoggerMutex());
|
||||
|
||||
if (DefaultLoggerManagerInstance().load() != nullptr) {
|
||||
ORT_THROW("Only one instance of LoggingManager created with InstanceType::Default can exist at any point in time.");
|
||||
@@ -127,7 +127,7 @@ LoggingManager::LoggingManager(std::unique_ptr<ISink> sink, Severity default_min
|
||||
LoggingManager::~LoggingManager() {
|
||||
if (owns_default_logger_) {
|
||||
// lock mutex to reset DefaultLoggerManagerInstance() and free default logger from this instance.
|
||||
- std::lock_guard<std::mutex> guard(DefaultLoggerMutex());
|
||||
+ std::lock_guard<std::mutex> guard(*DefaultLoggerMutex());
|
||||
#if ((__cplusplus >= 201703L) || (defined(_MSVC_LANG) && (_MSVC_LANG >= 201703L)))
|
||||
DefaultLoggerManagerInstance().store(nullptr, std::memory_order_release);
|
||||
#else
|
||||
14
sqlite3-ndarray-math/Cargo.toml
Normal file
@@ -0,0 +1,14 @@
|
||||
[package]
|
||||
name = "sqlite3-ndarray-math"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib", "staticlib"]
|
||||
|
||||
[dependencies]
|
||||
ndarray = "0.16.1"
|
||||
ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" }
|
||||
# ndarray-math = { path = "/Users/fs0c131y/Projects/ndarray-math", version = "0.1.0" }
|
||||
ndarray-safetensors = { version = "0.1.0", path = "../ndarray-safetensors" }
|
||||
sqlite-loadable = "0.0.5"
|
||||
61
sqlite3-ndarray-math/src/lib.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
use sqlite_loadable::prelude::*;
|
||||
use sqlite_loadable::{Error, ErrorKind};
|
||||
use sqlite_loadable::{Result, api, define_scalar_function};
|
||||
|
||||
fn cosine_similarity(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> {
|
||||
#[inline(always)]
|
||||
fn custom_error(err: impl core::error::Error) -> sqlite_loadable::Error {
|
||||
sqlite_loadable::Error::new(sqlite_loadable::ErrorKind::Message(err.to_string()))
|
||||
}
|
||||
|
||||
if values.len() != 2 {
|
||||
return Err(Error::new(ErrorKind::Message(
|
||||
"cosine_similarity requires exactly 2 arguments".to_string(),
|
||||
)));
|
||||
}
|
||||
let array_1 = api::value_blob(values.get(0).expect("1st argument"));
|
||||
let array_2 = api::value_blob(values.get(1).expect("2nd argument"));
|
||||
let array_1_st =
|
||||
ndarray_safetensors::SafeArraysView::from_bytes(array_1).map_err(custom_error)?;
|
||||
let array_2_st =
|
||||
ndarray_safetensors::SafeArraysView::from_bytes(array_2).map_err(custom_error)?;
|
||||
|
||||
let array_view_1 = array_1_st
|
||||
.tensor_by_index::<f32, ndarray::Ix1>(0)
|
||||
.map_err(custom_error)?;
|
||||
let array_view_2 = array_2_st
|
||||
.tensor_by_index::<f32, ndarray::Ix1>(0)
|
||||
.map_err(custom_error)?;
|
||||
|
||||
use ndarray_math::*;
|
||||
let similarity = array_view_1
|
||||
.cosine_similarity(array_view_2)
|
||||
.map_err(custom_error)?;
|
||||
api::result_double(context, similarity as f64);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn _sqlite3_extension_init(db: *mut sqlite3) -> Result<()> {
|
||||
define_scalar_function(
|
||||
db,
|
||||
"cosine_similarity",
|
||||
2,
|
||||
cosine_similarity,
|
||||
FunctionFlags::DETERMINISTIC,
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// Should only be called by underlying SQLite C APIs,
|
||||
/// like sqlite3_auto_extension and sqlite3_cancel_auto_extension.
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "C" fn sqlite3_extension_init(
|
||||
db: *mut sqlite3,
|
||||
pz_err_msg: *mut *mut c_char,
|
||||
p_api: *mut sqlite3_api_routines,
|
||||
) -> c_uint {
|
||||
register_entrypoint(db, pz_err_msg, p_api, _sqlite3_extension_init)
|
||||
}
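
// Hedged sketch (not part of the extension itself): building a blob that the
// cosine_similarity scalar function above can decode. SafeArrays::from_ndarrays and
// serialize are used the same way in src/database.rs; the Vec<u8> return type of
// serialize is an assumption here. A caller would then run SQL along the lines of
// `SELECT cosine_similarity(a.embedding, b.embedding) FROM embeddings a, embeddings b`.
#[allow(dead_code)]
fn embedding_blob_sketch(embedding: ndarray::ArrayView1<f32>) -> Vec<u8> {
    let arrays = ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding)])
        .expect("1-D f32 tensors are serializable");
    arrays.serialize().expect("safetensors serialization")
}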
|
||||
213
src/bin/detector-cli/cli.rs
Normal file
@@ -0,0 +1,213 @@
|
||||
use detector::ort_ep;
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[derive(Debug, clap::Parser)]
|
||||
pub struct Cli {
|
||||
#[clap(subcommand)]
|
||||
pub cmd: SubCommand,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Subcommand)]
|
||||
pub enum SubCommand {
|
||||
#[clap(name = "detect")]
|
||||
Detect(Detect),
|
||||
#[clap(name = "detect-multi")]
|
||||
DetectMulti(DetectMulti),
|
||||
#[clap(name = "query")]
|
||||
Query(Query),
|
||||
#[clap(name = "similar")]
|
||||
Similar(Similar),
|
||||
#[clap(name = "stats")]
|
||||
Stats(Stats),
|
||||
#[clap(name = "compare")]
|
||||
Compare(Compare),
|
||||
#[clap(name = "cluster")]
|
||||
Cluster(Cluster),
|
||||
#[clap(name = "gui")]
|
||||
Gui,
|
||||
#[clap(name = "completions")]
|
||||
Completions { shell: clap_complete::Shell },
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::ValueEnum, Clone, Copy, PartialEq)]
|
||||
pub enum Models {
|
||||
RetinaFace,
|
||||
Yolo,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Executor {
|
||||
Mnn(mnn::ForwardType),
|
||||
Ort(Vec<ort_ep::ExecutionProvider>),
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct Detect {
|
||||
#[clap(short, long)]
|
||||
pub model: Option<PathBuf>,
|
||||
#[clap(short = 'M', long, default_value = "retina-face")]
|
||||
pub model_type: Models,
|
||||
#[clap(short, long)]
|
||||
pub output: Option<PathBuf>,
|
||||
#[clap(
|
||||
short = 'p',
|
||||
long,
|
||||
default_value = "cpu",
|
||||
group = "execution_provider",
|
||||
required_unless_present = "mnn_forward_type"
|
||||
)]
|
||||
pub ort_execution_provider: Vec<ort_ep::ExecutionProvider>,
|
||||
#[clap(
|
||||
short = 'f',
|
||||
long,
|
||||
group = "execution_provider",
|
||||
required_unless_present = "ort_execution_provider"
|
||||
)]
|
||||
pub mnn_forward_type: Option<mnn::ForwardType>,
|
||||
#[clap(short, long, default_value_t = 0.8)]
|
||||
pub threshold: f32,
|
||||
#[clap(short, long, default_value_t = 0.3)]
|
||||
pub nms_threshold: f32,
|
||||
#[clap(short, long, default_value_t = 8)]
|
||||
pub batch_size: usize,
|
||||
#[clap(short = 'd', long)]
|
||||
pub database: Option<PathBuf>,
|
||||
#[clap(long, default_value = "facenet")]
|
||||
pub model_name: String,
|
||||
#[clap(long)]
|
||||
pub save_to_db: bool,
|
||||
pub image: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct DetectMulti {
|
||||
#[clap(short, long)]
|
||||
pub model: Option<PathBuf>,
|
||||
#[clap(short = 'M', long, default_value = "retina-face")]
|
||||
pub model_type: Models,
|
||||
#[clap(short, long)]
|
||||
pub output_dir: Option<PathBuf>,
|
||||
#[clap(
|
||||
short = 'p',
|
||||
long,
|
||||
default_value = "cpu",
|
||||
group = "execution_provider",
|
||||
required_unless_present = "mnn_forward_type"
|
||||
)]
|
||||
pub ort_execution_provider: Vec<ort_ep::ExecutionProvider>,
|
||||
#[clap(
|
||||
short = 'f',
|
||||
long,
|
||||
group = "execution_provider",
|
||||
required_unless_present = "ort_execution_provider"
|
||||
)]
|
||||
pub mnn_forward_type: Option<mnn::ForwardType>,
|
||||
#[clap(short, long, default_value_t = 0.8)]
|
||||
pub threshold: f32,
|
||||
#[clap(short, long, default_value_t = 0.3)]
|
||||
pub nms_threshold: f32,
|
||||
#[clap(short, long, default_value_t = 8)]
|
||||
pub batch_size: usize,
|
||||
#[clap(short = 'd', long, default_value = "face_detections.db")]
|
||||
pub database: PathBuf,
|
||||
#[clap(long, default_value = "facenet")]
|
||||
pub model_name: String,
|
||||
#[clap(
|
||||
long,
|
||||
help = "Image extensions to process (e.g., jpg,png,jpeg)",
|
||||
default_value = "jpg,jpeg,png,bmp,tiff,webp"
|
||||
)]
|
||||
pub extensions: String,
|
||||
#[clap(help = "Directory containing images to process")]
|
||||
pub input_dir: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct Query {
|
||||
#[clap(short = 'd', long, default_value = "face_detections.db")]
|
||||
pub database: PathBuf,
|
||||
#[clap(short, long)]
|
||||
pub image_id: Option<i64>,
|
||||
#[clap(short, long)]
|
||||
pub face_id: Option<i64>,
|
||||
#[clap(long)]
|
||||
pub show_embeddings: bool,
|
||||
#[clap(long)]
|
||||
pub show_landmarks: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct Similar {
|
||||
#[clap(short = 'd', long, default_value = "face_detections.db")]
|
||||
pub database: PathBuf,
|
||||
#[clap(short, long)]
|
||||
pub face_id: i64,
|
||||
#[clap(short, long, default_value_t = 0.7)]
|
||||
pub threshold: f32,
|
||||
#[clap(short, long, default_value_t = 10)]
|
||||
pub limit: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct Stats {
|
||||
#[clap(short = 'd', long, default_value = "face_detections.db")]
|
||||
pub database: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct Compare {
|
||||
#[clap(short, long)]
|
||||
pub model: Option<PathBuf>,
|
||||
#[clap(short = 'M', long, default_value = "retina-face")]
|
||||
pub model_type: Models,
|
||||
#[clap(
|
||||
short = 'p',
|
||||
long,
|
||||
default_value = "cpu",
|
||||
group = "execution_provider",
|
||||
required_unless_present = "mnn_forward_type"
|
||||
)]
|
||||
pub ort_execution_provider: Vec<ort_ep::ExecutionProvider>,
|
||||
#[clap(
|
||||
short = 'f',
|
||||
long,
|
||||
group = "execution_provider",
|
||||
required_unless_present = "ort_execution_provider"
|
||||
)]
|
||||
pub mnn_forward_type: Option<mnn::ForwardType>,
|
||||
#[clap(short, long, default_value_t = 0.8)]
|
||||
pub threshold: f32,
|
||||
#[clap(short, long, default_value_t = 0.3)]
|
||||
pub nms_threshold: f32,
|
||||
#[clap(short, long, default_value_t = 8)]
|
||||
pub batch_size: usize,
|
||||
#[clap(long, default_value = "facenet")]
|
||||
pub model_name: String,
|
||||
#[clap(help = "First image to compare")]
|
||||
pub image1: PathBuf,
|
||||
#[clap(help = "Second image to compare")]
|
||||
pub image2: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct Cluster {
|
||||
#[clap(short = 'd', long, default_value = "face_detections.db")]
|
||||
pub database: PathBuf,
|
||||
#[clap(short, long, default_value_t = 5)]
|
||||
pub clusters: usize,
|
||||
#[clap(short, long, default_value_t = 100)]
|
||||
pub max_iterations: u64,
|
||||
#[clap(short, long, default_value_t = 1e-4)]
|
||||
pub tolerance: f64,
|
||||
#[clap(long, default_value = "facenet")]
|
||||
pub model_name: String,
|
||||
#[clap(short, long)]
|
||||
pub output: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl Cli {
|
||||
pub fn completions(shell: clap_complete::Shell) {
|
||||
let mut command = <Cli as clap::CommandFactory>::command();
|
||||
clap_complete::generate(shell, &mut command, "detector", &mut std::io::stdout());
|
||||
}
|
||||
}
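
// Hedged sketch: exercising the clap derive definitions above in isolation,
// without loading any models or touching a real database file.
#[test]
fn test_cli_parse_sketch() {
    use clap::Parser as _;
    let cli = Cli::parse_from(["detector", "stats", "--database", "faces.db"]);
    match cli.cmd {
        SubCommand::Stats(s) => assert_eq!(s.database, std::path::PathBuf::from("faces.db")),
        _ => panic!("expected the stats subcommand"),
    }
}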
|
||||
1063
src/bin/detector-cli/main.rs
Normal file
File diff suppressed because it is too large
19
src/bin/gui.rs
Normal file
@@ -0,0 +1,19 @@
|
||||
use detector::errors::*;
|
||||
fn main() -> Result<()> {
|
||||
// Initialize logging
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter("warn,ort=warn")
|
||||
.with_file(true)
|
||||
.with_line_number(true)
|
||||
// .with_thread_names(true)
|
||||
.with_target(true)
|
||||
.init();
|
||||
|
||||
// Run the GUI
|
||||
if let Err(e) = detector::gui::run() {
|
||||
eprintln!("GUI error: {}", e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
71
src/cli.rs
@@ -1,71 +0,0 @@
|
||||
use std::path::PathBuf;
|
||||
#[derive(Debug, clap::Parser)]
|
||||
pub struct Cli {
|
||||
#[clap(subcommand)]
|
||||
pub cmd: SubCommand,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Subcommand)]
|
||||
pub enum SubCommand {
|
||||
#[clap(name = "detect")]
|
||||
Detect(Detect),
|
||||
#[clap(name = "list")]
|
||||
List(List),
|
||||
#[clap(name = "completions")]
|
||||
Completions { shell: clap_complete::Shell },
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::ValueEnum, Clone, Copy)]
|
||||
pub enum Models {
|
||||
RetinaFace,
|
||||
Yolo,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::ValueEnum, Clone, Copy)]
|
||||
pub enum Executor {
|
||||
Mnn,
|
||||
Onnx,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::ValueEnum, Clone, Copy)]
|
||||
pub enum OnnxEp {
|
||||
Cpu,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::ValueEnum, Clone, Copy)]
|
||||
pub enum MnnEp {
|
||||
Cpu,
|
||||
Metal,
|
||||
OpenCL,
|
||||
CoreML,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct Detect {
|
||||
#[clap(short, long)]
|
||||
pub model: Option<PathBuf>,
|
||||
#[clap(short = 'M', long, default_value = "retina-face")]
|
||||
pub model_type: Models,
|
||||
#[clap(short, long)]
|
||||
pub output: Option<PathBuf>,
|
||||
#[clap(short, long, default_value_t = 0.8)]
|
||||
pub threshold: f32,
|
||||
#[clap(short, long, default_value_t = 0.3)]
|
||||
pub nms_threshold: f32,
|
||||
pub image: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct List {}
|
||||
|
||||
impl Cli {
|
||||
pub fn completions(shell: clap_complete::Shell) {
|
||||
let mut command = <Cli as clap::CommandFactory>::command();
|
||||
clap_complete::generate(
|
||||
shell,
|
||||
&mut command,
|
||||
env!("CARGO_BIN_NAME"),
|
||||
&mut std::io::stdout(),
|
||||
);
|
||||
}
|
||||
}
|
||||
734
src/database.rs
Normal file
@@ -0,0 +1,734 @@
|
||||
use crate::errors::{Error, Result};
|
||||
use crate::facedet::{FaceDetectionOutput, FaceLandmarks};
|
||||
use bounding_box::Aabb2;
|
||||
use error_stack::ResultExt;
|
||||
use ndarray_math::CosineSimilarity;
|
||||
use rusqlite::{Connection, OptionalExtension, params};
|
||||
use std::path::Path;
|
||||
|
||||
/// Database connection and operations for face detection results
|
||||
pub struct FaceDatabase {
|
||||
conn: Connection,
|
||||
}
|
||||
|
||||
/// Represents a stored image record
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ImageRecord {
|
||||
pub id: i64,
|
||||
pub file_path: String,
|
||||
pub width: u32,
|
||||
pub height: u32,
|
||||
pub created_at: String,
|
||||
}
|
||||
|
||||
/// Represents a stored face detection record
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FaceRecord {
|
||||
pub id: i64,
|
||||
pub image_id: i64,
|
||||
pub bbox_x1: f32,
|
||||
pub bbox_y1: f32,
|
||||
pub bbox_x2: f32,
|
||||
pub bbox_y2: f32,
|
||||
pub confidence: f32,
|
||||
pub created_at: String,
|
||||
}
|
||||
|
||||
/// Represents stored face landmarks
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LandmarkRecord {
|
||||
pub id: i64,
|
||||
pub face_id: i64,
|
||||
pub left_eye_x: f32,
|
||||
pub left_eye_y: f32,
|
||||
pub right_eye_x: f32,
|
||||
pub right_eye_y: f32,
|
||||
pub nose_x: f32,
|
||||
pub nose_y: f32,
|
||||
pub left_mouth_x: f32,
|
||||
pub left_mouth_y: f32,
|
||||
pub right_mouth_x: f32,
|
||||
pub right_mouth_y: f32,
|
||||
}
|
||||
|
||||
/// Represents a stored face embedding
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EmbeddingRecord {
|
||||
pub id: i64,
|
||||
pub face_id: i64,
|
||||
pub embedding: ndarray::Array1<f32>,
|
||||
pub model_name: String,
|
||||
pub created_at: String,
|
||||
}
|
||||
|
||||
impl FaceDatabase {
|
||||
/// Create a new database connection and initialize tables
|
||||
pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Self> {
|
||||
let conn = Connection::open(db_path).change_context(Error)?;
|
||||
// Temporarily disable extension loading for clustering demo
|
||||
// unsafe {
|
||||
// let _guard = rusqlite::LoadExtensionGuard::new(&conn).change_context(Error)?;
|
||||
// conn.load_extension(
|
||||
// "/Users/fs0c131y/.cache/cargo/target/release/libsqlite3_safetensor_cosine.dylib",
|
||||
// None::<&str>,
|
||||
// )
|
||||
// .change_context(Error)?;
|
||||
// }
|
||||
let db = Self { conn };
|
||||
db.create_tables()?;
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
/// Create an in-memory database for testing
|
||||
pub fn in_memory() -> Result<Self> {
|
||||
let conn = Connection::open_in_memory().change_context(Error)?;
|
||||
let db = Self { conn };
|
||||
db.create_tables()?;
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
/// Create all necessary database tables
|
||||
fn create_tables(&self) -> Result<()> {
|
||||
// Images table
|
||||
self.conn
|
||||
.execute(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS images (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
file_path TEXT NOT NULL UNIQUE,
|
||||
width INTEGER NOT NULL,
|
||||
height INTEGER NOT NULL,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"#,
|
||||
[],
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
// Faces table
|
||||
self.conn
|
||||
.execute(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS faces (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
image_id INTEGER NOT NULL,
|
||||
bbox_x1 REAL NOT NULL,
|
||||
bbox_y1 REAL NOT NULL,
|
||||
bbox_x2 REAL NOT NULL,
|
||||
bbox_y2 REAL NOT NULL,
|
||||
confidence REAL NOT NULL,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE
|
||||
)
|
||||
"#,
|
||||
[],
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
// Landmarks table
|
||||
self.conn
|
||||
.execute(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS landmarks (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
face_id INTEGER NOT NULL,
|
||||
left_eye_x REAL NOT NULL,
|
||||
left_eye_y REAL NOT NULL,
|
||||
right_eye_x REAL NOT NULL,
|
||||
right_eye_y REAL NOT NULL,
|
||||
nose_x REAL NOT NULL,
|
||||
nose_y REAL NOT NULL,
|
||||
left_mouth_x REAL NOT NULL,
|
||||
left_mouth_y REAL NOT NULL,
|
||||
right_mouth_x REAL NOT NULL,
|
||||
right_mouth_y REAL NOT NULL,
|
||||
FOREIGN KEY (face_id) REFERENCES faces (id) ON DELETE CASCADE
|
||||
)
|
||||
"#,
|
||||
[],
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
// Embeddings table
|
||||
self.conn
|
||||
.execute(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS embeddings (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
face_id INTEGER NOT NULL,
|
||||
embedding BLOB NOT NULL,
|
||||
model_name TEXT NOT NULL,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (face_id) REFERENCES faces (id) ON DELETE CASCADE
|
||||
)
|
||||
"#,
|
||||
[],
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
// Create indexes for better performance
|
||||
self.conn
|
||||
.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_faces_image_id ON faces (image_id)",
|
||||
[],
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
self.conn
|
||||
.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_landmarks_face_id ON landmarks (face_id)",
|
||||
[],
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
self.conn
|
||||
.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_embeddings_face_id ON embeddings (face_id)",
|
||||
[],
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Store image metadata and return the image ID
|
||||
pub fn store_image(&self, file_path: &str, width: u32, height: u32) -> Result<i64> {
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare("INSERT OR REPLACE INTO images (file_path, width, height) VALUES (?1, ?2, ?3)")
|
||||
.change_context(Error)?;
|
||||
|
||||
Ok(stmt
|
||||
.insert(params![file_path, width, height])
|
||||
.change_context(Error)?)
|
||||
}
|
||||
|
||||
/// Store face detection results
|
||||
pub fn store_face_detections(
|
||||
&self,
|
||||
image_id: i64,
|
||||
detection_output: &FaceDetectionOutput,
|
||||
) -> Result<Vec<i64>> {
|
||||
let mut face_ids = Vec::new();
|
||||
|
||||
for (i, bbox) in detection_output.bbox.iter().enumerate() {
|
||||
let confidence = detection_output.confidence.get(i).copied().unwrap_or(0.0);
|
||||
|
||||
let face_id = self.store_face(image_id, bbox, confidence)?;
|
||||
face_ids.push(face_id);
|
||||
|
||||
// Store landmarks if available
|
||||
if let Some(landmarks) = detection_output.landmark.get(i) {
|
||||
self.store_landmarks(face_id, landmarks)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(face_ids)
|
||||
}
|
||||
|
||||
    /// Store a single face detection
    pub fn store_face(&self, image_id: i64, bbox: &Aabb2<usize>, confidence: f32) -> Result<i64> {
        let mut stmt = self
            .conn
            .prepare(
                r#"
                INSERT INTO faces (image_id, bbox_x1, bbox_y1, bbox_x2, bbox_y2, confidence)
                VALUES (?1, ?2, ?3, ?4, ?5, ?6)
                "#,
            )
            .change_context(Error)?;

        Ok(stmt
            .insert(params![
                image_id,
                bbox.x1() as f32,
                bbox.y1() as f32,
                bbox.x2() as f32,
                bbox.y2() as f32,
                confidence
            ])
            .change_context(Error)?)
    }

    /// Store face landmarks
    pub fn store_landmarks(&self, face_id: i64, landmarks: &FaceLandmarks) -> Result<i64> {
        let mut stmt = self
            .conn
            .prepare(
                r#"
                INSERT INTO landmarks
                    (face_id, left_eye_x, left_eye_y, right_eye_x, right_eye_y,
                     nose_x, nose_y, left_mouth_x, left_mouth_y, right_mouth_x, right_mouth_y)
                VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)
                "#,
            )
            .change_context(Error)?;

        Ok(stmt
            .insert(params![
                face_id,
                landmarks.left_eye.x,
                landmarks.left_eye.y,
                landmarks.right_eye.x,
                landmarks.right_eye.y,
                landmarks.nose.x,
                landmarks.nose.y,
                landmarks.left_mouth.x,
                landmarks.left_mouth.y,
                landmarks.right_mouth.x,
                landmarks.right_mouth.y,
            ])
            .change_context(Error)?)
    }

    /// Store face embeddings
    pub fn store_embeddings(
        &self,
        face_ids: &[i64],
        embeddings: &[ndarray::Array2<f32>],
        model_name: &str,
    ) -> Result<Vec<i64>> {
        let mut embedding_ids = Vec::new();

        for (face_idx, embedding_batch) in embeddings.iter().enumerate() {
            for (batch_idx, embedding_row) in embedding_batch.rows().into_iter().enumerate() {
                let global_idx = face_idx * embedding_batch.nrows() + batch_idx;

                if global_idx >= face_ids.len() {
                    break;
                }

                let face_id = face_ids[global_idx];
                let embedding_id =
                    self.store_single_embedding(face_id, embedding_row, model_name)?;
                embedding_ids.push(embedding_id);
            }
        }

        Ok(embedding_ids)
    }

    /// Store a single embedding
    pub fn store_single_embedding(
        &self,
        face_id: i64,
        embedding: ndarray::ArrayView1<f32>,
        model_name: &str,
    ) -> Result<i64> {
        let safe_arrays =
            ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding)])
                .change_context(Error)?;

        let embedding_bytes = safe_arrays.serialize().change_context(Error)?;

        let mut stmt = self
            .conn
            .prepare("INSERT INTO embeddings (face_id, embedding, model_name) VALUES (?1, ?2, ?3)")
            .change_context(Error)?;

        stmt.execute(params![face_id, embedding_bytes, model_name])
            .change_context(Error)?;

        Ok(self.conn.last_insert_rowid())
    }
|
||||
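    // The BLOB written above round-trips through ndarray-safetensors. A minimal
    // decoding sketch (assuming the same "embedding" tensor key used in this file):
    //
    //     let view = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes)?;
    //     let embedding = view.tensor::<f32, ndarray::Ix1>("embedding")?.to_owned();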
|
||||
/// Get image by ID
|
||||
pub fn get_image(&self, image_id: i64) -> Result<Option<ImageRecord>> {
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare("SELECT id, file_path, width, height, created_at FROM images WHERE id = ?1")
|
||||
.change_context(Error)?;
|
||||
|
||||
let result = stmt
|
||||
.query_row(params![image_id], |row| {
|
||||
Ok(ImageRecord {
|
||||
id: row.get(0)?,
|
||||
file_path: row.get(1)?,
|
||||
width: row.get(2)?,
|
||||
height: row.get(3)?,
|
||||
created_at: row.get(4)?,
|
||||
})
|
||||
})
|
||||
.optional()
|
||||
.change_context(Error)?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Get all faces for an image
|
||||
pub fn get_faces_for_image(&self, image_id: i64) -> Result<Vec<FaceRecord>> {
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare(
|
||||
r#"
|
||||
SELECT id, image_id, bbox_x1, bbox_y1, bbox_x2, bbox_y2, confidence, created_at
|
||||
FROM faces WHERE image_id = ?1
|
||||
"#,
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
let face_iter = stmt
|
||||
.query_map(params![image_id], |row| {
|
||||
Ok(FaceRecord {
|
||||
id: row.get(0)?,
|
||||
image_id: row.get(1)?,
|
||||
bbox_x1: row.get(2)?,
|
||||
bbox_y1: row.get(3)?,
|
||||
bbox_x2: row.get(4)?,
|
||||
bbox_y2: row.get(5)?,
|
||||
confidence: row.get(6)?,
|
||||
created_at: row.get(7)?,
|
||||
})
|
||||
})
|
||||
.change_context(Error)?;
|
||||
|
||||
let mut faces = Vec::new();
|
||||
for face in face_iter {
|
||||
faces.push(face.change_context(Error)?);
|
||||
}
|
||||
|
||||
Ok(faces)
|
||||
}
|
||||
|
||||
/// Get landmarks for a face
|
||||
pub fn get_landmarks(&self, face_id: i64) -> Result<Option<LandmarkRecord>> {
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare(
|
||||
r#"
|
||||
SELECT id, face_id, left_eye_x, left_eye_y, right_eye_x, right_eye_y,
|
||||
nose_x, nose_y, left_mouth_x, left_mouth_y, right_mouth_x, right_mouth_y
|
||||
FROM landmarks WHERE face_id = ?1
|
||||
"#,
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
let result = stmt
|
||||
.query_row(params![face_id], |row| {
|
||||
Ok(LandmarkRecord {
|
||||
id: row.get(0)?,
|
||||
face_id: row.get(1)?,
|
||||
left_eye_x: row.get(2)?,
|
||||
left_eye_y: row.get(3)?,
|
||||
right_eye_x: row.get(4)?,
|
||||
right_eye_y: row.get(5)?,
|
||||
nose_x: row.get(6)?,
|
||||
nose_y: row.get(7)?,
|
||||
left_mouth_x: row.get(8)?,
|
||||
left_mouth_y: row.get(9)?,
|
||||
right_mouth_x: row.get(10)?,
|
||||
right_mouth_y: row.get(11)?,
|
||||
})
|
||||
})
|
||||
.optional()
|
||||
.change_context(Error)?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Get embeddings for a face
|
||||
pub fn get_embeddings(&self, face_id: i64) -> Result<Vec<EmbeddingRecord>> {
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare(
|
||||
"SELECT id, face_id, embedding, model_name, created_at FROM embeddings WHERE face_id = ?1",
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
let embedding_iter = stmt
|
||||
.query_map(params![face_id], |row| {
|
||||
let embedding_bytes: Vec<u8> = row.get(2)?;
|
||||
                let embedding: ndarray::Array1<f32> = {
                    // Propagate deserialization failures instead of panicking.
                    let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes)
                        .map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
                    sf.tensor::<f32, ndarray::Ix1>("embedding")
                        .map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?
                        .to_owned()
                };
|
||||
|
||||
Ok(EmbeddingRecord {
|
||||
id: row.get(0)?,
|
||||
face_id: row.get(1)?,
|
||||
embedding,
|
||||
model_name: row.get(3)?,
|
||||
created_at: row.get(4)?,
|
||||
})
|
||||
})
|
||||
.change_context(Error)?;
|
||||
|
||||
let mut embeddings = Vec::new();
|
||||
for embedding in embedding_iter {
|
||||
embeddings.push(embedding.change_context(Error)?);
|
||||
}
|
||||
|
||||
Ok(embeddings)
|
||||
}
|
||||
|
||||
pub fn get_image_for_face(&self, face_id: i64) -> Result<Option<ImageRecord>> {
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare(
|
||||
r#"
|
||||
SELECT images.id, images.file_path, images.width, images.height, images.created_at
|
||||
FROM images
|
||||
JOIN faces ON faces.image_id = images.id
|
||||
WHERE faces.id = ?1
|
||||
"#,
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
let result = stmt
|
||||
.query_row(params![face_id], |row| {
|
||||
Ok(ImageRecord {
|
||||
id: row.get(0)?,
|
||||
file_path: row.get(1)?,
|
||||
width: row.get(2)?,
|
||||
height: row.get(3)?,
|
||||
created_at: row.get(4)?,
|
||||
})
|
||||
})
|
||||
.optional()
|
||||
.change_context(Error)?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Get database statistics
|
||||
pub fn get_stats(&self) -> Result<(usize, usize, usize, usize)> {
|
||||
let images: usize = self
|
||||
.conn
|
||||
.query_row("SELECT COUNT(*) FROM images", [], |row| row.get(0))
|
||||
.change_context(Error)?;
|
||||
|
||||
let faces: usize = self
|
||||
.conn
|
||||
.query_row("SELECT COUNT(*) FROM faces", [], |row| row.get(0))
|
||||
.change_context(Error)?;
|
||||
|
||||
let landmarks: usize = self
|
||||
.conn
|
||||
.query_row("SELECT COUNT(*) FROM landmarks", [], |row| row.get(0))
|
||||
.change_context(Error)?;
|
||||
|
||||
let embeddings: usize = self
|
||||
.conn
|
||||
.query_row("SELECT COUNT(*) FROM embeddings", [], |row| row.get(0))
|
||||
.change_context(Error)?;
|
||||
|
||||
Ok((images, faces, landmarks, embeddings))
|
||||
}
|
||||
|
||||
/// Find similar faces based on cosine similarity of embeddings
|
||||
/// Return ids and similarity scores of similar faces
|
||||
pub fn find_similar_faces(
|
||||
&self,
|
||||
embedding: &ndarray::Array1<f32>,
|
||||
threshold: f32,
|
||||
limit: usize,
|
||||
) -> Result<Vec<(i64, f32)>> {
|
||||
// Serialize the query embedding to bytes
|
||||
let embedding_bytes =
|
||||
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding.view())])
|
||||
.change_context(Error)?
|
||||
.serialize()
|
||||
.change_context(Error)?;
|
||||
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare(
|
||||
r#" SELECT face_id, cosine_similarity(?1, embedding) as similarity
|
||||
FROM embeddings
|
||||
WHERE cosine_similarity(?1, embedding) >= ?2
|
||||
ORDER BY similarity DESC
|
||||
LIMIT ?3"#,
|
||||
)
|
||||
.change_context(Error)?;
|
||||
|
||||
let result = stmt
|
||||
.query_map(params![embedding_bytes, threshold, limit], |row| {
|
||||
Ok((row.get::<_, i64>(0)?, row.get::<_, f32>(1)?))
|
||||
})
|
||||
.change_context(Error)?
|
||||
.map(|r| r.change_context(Error))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
        Ok(result)
|
||||
}
|
||||
|
||||
pub fn query_similarity(&self, embedding: &ndarray::Array1<f32>) {
|
||||
let embedding_bytes =
|
||||
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding.view())])
|
||||
.change_context(Error)
|
||||
.unwrap()
|
||||
.serialize()
|
||||
.change_context(Error)
|
||||
.unwrap();
|
||||
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare(
|
||||
r#"
|
||||
SELECT face_id,
|
||||
cosine_similarity(?1, embedding)
|
||||
FROM embeddings
|
||||
"#,
|
||||
)
|
||||
.change_context(Error)
|
||||
.unwrap();
|
||||
|
||||
let result_iter = stmt
|
||||
.query_map(params![embedding_bytes], |row| {
|
||||
Ok((row.get::<_, i64>(0)?, row.get::<_, f32>(1)?))
|
||||
})
|
||||
.change_context(Error)
|
||||
.unwrap();
|
||||
|
||||
for result in result_iter {
|
||||
println!("{:?}", result);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all embeddings for a specific model
|
||||
pub fn get_all_embeddings(&self, model_name: Option<&str>) -> Result<Vec<EmbeddingRecord>> {
|
||||
let mut embeddings = Vec::new();
|
||||
|
||||
if let Some(model) = model_name {
|
||||
let mut stmt = self.conn.prepare(
|
||||
"SELECT id, face_id, embedding, model_name, created_at FROM embeddings WHERE model_name = ?1"
|
||||
).change_context(Error)?;
|
||||
|
||||
let embedding_iter = stmt
|
||||
.query_map(params![model], |row| {
|
||||
let embedding_bytes: Vec<u8> = row.get(2)?;
|
||||
let embedding: ndarray::Array1<f32> = {
|
||||
let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes)
|
||||
.change_context(Error)
|
||||
.unwrap();
|
||||
sf.tensor::<f32, ndarray::Ix1>("embedding")
|
||||
.unwrap()
|
||||
.to_owned()
|
||||
};
|
||||
|
||||
Ok(EmbeddingRecord {
|
||||
id: row.get(0)?,
|
||||
face_id: row.get(1)?,
|
||||
embedding,
|
||||
model_name: row.get(3)?,
|
||||
created_at: row.get(4)?,
|
||||
})
|
||||
})
|
||||
.change_context(Error)?;
|
||||
|
||||
for embedding in embedding_iter {
|
||||
embeddings.push(embedding.change_context(Error)?);
|
||||
}
|
||||
} else {
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare("SELECT id, face_id, embedding, model_name, created_at FROM embeddings")
|
||||
.change_context(Error)?;
|
||||
|
||||
let embedding_iter = stmt
|
||||
.query_map([], |row| {
|
||||
let embedding_bytes: Vec<u8> = row.get(2)?;
|
||||
let embedding: ndarray::Array1<f32> = {
|
||||
let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes)
|
||||
.change_context(Error)
|
||||
.unwrap();
|
||||
sf.tensor::<f32, ndarray::Ix1>("embedding")
|
||||
.unwrap()
|
||||
.to_owned()
|
||||
};
|
||||
|
||||
Ok(EmbeddingRecord {
|
||||
id: row.get(0)?,
|
||||
face_id: row.get(1)?,
|
||||
embedding,
|
||||
model_name: row.get(3)?,
|
||||
created_at: row.get(4)?,
|
||||
})
|
||||
})
|
||||
.change_context(Error)?;
|
||||
|
||||
for embedding in embedding_iter {
|
||||
embeddings.push(embedding.change_context(Error)?);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(embeddings)
|
||||
}
|
||||
}
|
||||
|
||||
fn add_sqlite_cosine_similarity(db: &Connection) -> Result<()> {
|
||||
use rusqlite::functions::*;
|
||||
db.create_scalar_function(
|
||||
"cosine_similarity",
|
||||
2,
|
||||
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
|
||||
move |ctx| {
|
||||
if ctx.len() != 2 {
|
||||
return Err(rusqlite::Error::UserFunctionError(
|
||||
"cosine_similarity requires exactly 2 arguments".into(),
|
||||
));
|
||||
}
|
||||
let array_1 = ctx.get_raw(0).as_blob()?;
|
||||
let array_2 = ctx.get_raw(1).as_blob()?;
|
||||
|
||||
let array_1_st = ndarray_safetensors::SafeArraysView::from_bytes(array_1)
|
||||
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||
let array_2_st = ndarray_safetensors::SafeArraysView::from_bytes(array_2)
|
||||
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||
|
||||
let array_view_1 = array_1_st
|
||||
.tensor_by_index::<f32, ndarray::Ix1>(0)
|
||||
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||
let array_view_2 = array_2_st
|
||||
.tensor_by_index::<f32, ndarray::Ix1>(0)
|
||||
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||
|
||||
let similarity = array_view_1
|
||||
.cosine_similarity(array_view_2)
|
||||
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||
|
||||
Ok(similarity)
|
||||
},
|
||||
)
|
||||
.change_context(Error)
|
||||
}
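// Once registered, the function is available to ordinary SQL, e.g. the query
// used by `find_similar_faces` above (sketch):
//
//     SELECT face_id, cosine_similarity(?1, embedding) AS similarity
//     FROM embeddings
//     WHERE cosine_similarity(?1, embedding) >= ?2
//     ORDER BY similarity DESC
//     LIMIT ?3;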
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_database_creation() -> Result<()> {
|
||||
let db = FaceDatabase::in_memory()?;
|
||||
let (images, faces, landmarks, embeddings) = db.get_stats()?;
|
||||
assert_eq!(images, 0);
|
||||
assert_eq!(faces, 0);
|
||||
assert_eq!(landmarks, 0);
|
||||
assert_eq!(embeddings, 0);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_store_and_retrieve_image() -> Result<()> {
|
||||
let db = FaceDatabase::in_memory()?;
|
||||
let image_id = db.store_image("/path/to/image.jpg", 800, 600)?;
|
||||
|
||||
let image = db.get_image(image_id)?.unwrap();
|
||||
assert_eq!(image.file_path, "/path/to/image.jpg");
|
||||
assert_eq!(image.width, 800);
|
||||
assert_eq!(image.height, 600);
|
||||
|
||||
Ok(())
|
||||
}
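
    /// Sketch of an end-to-end check for `find_similar_faces`; it assumes the
    /// `cosine_similarity` SQL function is registered by the constructors above.
    #[test]
    fn test_find_similar_faces_sketch() -> Result<()> {
        let db = FaceDatabase::in_memory()?;
        let image_id = db.store_image("/path/to/image.jpg", 800, 600)?;
        let face_id = db.store_face(image_id, &Aabb2::from_x1y1x2y2(0, 0, 10, 10), 0.9)?;

        let embedding = ndarray::Array1::from(vec![1.0_f32, 0.0, 0.0]);
        db.store_single_embedding(face_id, embedding.view(), "facenet")?;

        // An identical query vector should come back with similarity close to 1.0.
        let matches = db.find_similar_faces(&embedding, 0.5, 10)?;
        assert_eq!(matches.first().map(|(id, _)| *id), Some(face_id));
        Ok(())
    }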
|
||||
}
|
||||
@@ -1,2 +1,8 @@
|
||||
pub mod retinaface;
|
||||
pub mod yolo;
|
||||
|
||||
// Re-export common types and traits
|
||||
pub use retinaface::{
|
||||
FaceDetectionConfig, FaceDetectionModelOutput, FaceDetectionOutput,
|
||||
FaceDetectionProcessedOutput, FaceDetector, FaceLandmarks,
|
||||
};
|
||||
|
||||
@@ -1,67 +1,88 @@
|
||||
pub mod mnn;
|
||||
pub mod ort;
|
||||
|
||||
use crate::errors::*;
|
||||
use bounding_box::{Aabb2, nms::nms};
|
||||
use error_stack::ResultExt;
|
||||
use mnn_bridge::ndarray::*;
|
||||
use nalgebra::{Point2, Vector2};
|
||||
use ndarray_resize::NdFir;
|
||||
use std::path::Path;
|
||||
|
||||
/// Configuration for face detection postprocessing
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct FaceDetectionConfig {
|
||||
anchor_sizes: Vec<Vector2<usize>>,
|
||||
steps: Vec<usize>,
|
||||
variance: Vec<f32>,
|
||||
threshold: f32,
|
||||
nms_threshold: f32,
|
||||
/// Minimum confidence to keep a detection
|
||||
pub threshold: f32,
|
||||
/// NMS threshold for suppressing overlapping boxes
|
||||
pub nms_threshold: f32,
|
||||
/// Variances for bounding box decoding
|
||||
pub variances: [f32; 2],
|
||||
/// The step size (stride) for each feature map
|
||||
pub steps: Vec<usize>,
|
||||
/// The minimum anchor sizes for each feature map
|
||||
pub min_sizes: Vec<Vec<usize>>,
|
||||
/// Whether to clip bounding boxes to the image dimensions
|
||||
pub clamp: bool,
|
||||
/// Input image width (used for anchor generation)
|
||||
pub input_width: usize,
|
||||
/// Input image height (used for anchor generation)
|
||||
pub input_height: usize,
|
||||
}
|
||||
|
||||
impl FaceDetectionConfig {
|
||||
pub fn with_min_sizes(mut self, min_sizes: Vec<Vector2<usize>>) -> Self {
|
||||
self.anchor_sizes = min_sizes;
|
||||
self
|
||||
}
|
||||
pub fn with_steps(mut self, steps: Vec<usize>) -> Self {
|
||||
self.steps = steps;
|
||||
self
|
||||
}
|
||||
pub fn with_variance(mut self, variance: Vec<f32>) -> Self {
|
||||
self.variance = variance;
|
||||
self
|
||||
}
|
||||
pub fn with_threshold(mut self, threshold: f32) -> Self {
|
||||
self.threshold = threshold;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_nms_threshold(mut self, nms_threshold: f32) -> Self {
|
||||
self.nms_threshold = nms_threshold;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_variances(mut self, variances: [f32; 2]) -> Self {
|
||||
self.variances = variances;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_steps(mut self, steps: Vec<usize>) -> Self {
|
||||
self.steps = steps;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_min_sizes(mut self, min_sizes: Vec<Vec<usize>>) -> Self {
|
||||
self.min_sizes = min_sizes;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_clip(mut self, clip: bool) -> Self {
|
||||
self.clamp = clip;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_input_width(mut self, input_width: usize) -> Self {
|
||||
self.input_width = input_width;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_input_height(mut self, input_height: usize) -> Self {
|
||||
self.input_height = input_height;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for FaceDetectionConfig {
|
||||
fn default() -> Self {
|
||||
FaceDetectionConfig {
|
||||
anchor_sizes: vec![
|
||||
Vector2::new(16, 32),
|
||||
Vector2::new(64, 128),
|
||||
Vector2::new(256, 512),
|
||||
],
|
||||
steps: vec![8, 16, 32],
|
||||
variance: vec![0.1, 0.2],
|
||||
threshold: 0.8,
|
||||
Self {
|
||||
threshold: 0.5,
|
||||
nms_threshold: 0.4,
|
||||
variances: [0.1, 0.2],
|
||||
steps: vec![8, 16, 32],
|
||||
min_sizes: vec![vec![16, 32], vec![64, 128], vec![256, 512]],
|
||||
clamp: true,
|
||||
input_width: 1024,
|
||||
input_height: 1024,
|
||||
}
|
||||
}
|
||||
}
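// Worked example for the defaults above: with a 1024x1024 input and steps
// [8, 16, 32] the feature maps are 128x128, 64x64 and 32x32, and each cell
// gets one anchor per min_size, so the prior count is
//   (128*128 + 64*64 + 32*32) * 2 = (16384 + 4096 + 1024) * 2 = 43008,
// which should match the prior dimension of the raw model outputs below.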
|
||||
pub struct FaceDetection {
|
||||
handle: mnn_sync::SessionHandle,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct FaceDetectionModelOutput {
|
||||
pub bbox: ndarray::Array3<f32>,
|
||||
pub confidence: ndarray::Array3<f32>,
|
||||
pub landmark: ndarray::Array3<f32>,
|
||||
}
|
||||
|
||||
/// Represents the 5 facial landmarks detected by RetinaFace
|
||||
#[derive(Debug, Copy, Clone, PartialEq)]
|
||||
@@ -73,6 +94,13 @@ pub struct FaceLandmarks {
|
||||
pub right_mouth: Point2<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct FaceDetectionModelOutput {
|
||||
pub bbox: ndarray::Array3<f32>,
|
||||
pub confidence: ndarray::Array3<f32>,
|
||||
pub landmark: ndarray::Array3<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct FaceDetectionProcessedOutput {
|
||||
pub bbox: Vec<Aabb2<f32>>,
|
||||
@@ -87,85 +115,133 @@ pub struct FaceDetectionOutput {
|
||||
pub landmark: Vec<FaceLandmarks>,
|
||||
}
|
||||
|
||||
impl FaceDetectionModelOutput {
|
||||
pub fn postprocess(self, config: &FaceDetectionConfig) -> Result<FaceDetectionProcessedOutput> {
|
||||
/// Raw model outputs that can be converted to FaceDetectionModelOutput
|
||||
pub trait IntoModelOutput {
|
||||
fn into_model_output(self) -> Result<FaceDetectionModelOutput>;
|
||||
}
|
||||
|
||||
/// Generate anchors for RetinaFace model
|
||||
pub fn generate_anchors(config: &FaceDetectionConfig) -> ndarray::Array2<f32> {
|
||||
let mut anchors = Vec::new();
|
||||
for (k, &step) in config.steps.iter().enumerate() {
|
||||
let feature_size = 1024 / step;
|
||||
let min_sizes = config.anchor_sizes[k];
|
||||
let sizes = [min_sizes.x, min_sizes.y];
|
||||
for i in 0..feature_size {
|
||||
for j in 0..feature_size {
|
||||
for &size in &sizes {
|
||||
let cx = (j as f32 + 0.5) * step as f32 / 1024.0;
|
||||
let cy = (i as f32 + 0.5) * step as f32 / 1024.0;
|
||||
let s_k = size as f32 / 1024.0;
|
||||
anchors.push((cx, cy, s_k, s_k));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut boxes = Vec::new();
|
||||
let mut scores = Vec::new();
|
||||
let mut landmarks = Vec::new();
|
||||
let var0 = config.variance[0];
|
||||
let var1 = config.variance[1];
|
||||
let bbox_data = self.bbox;
|
||||
let conf_data = self.confidence;
|
||||
let landmark_data = self.landmark;
|
||||
let num_priors = bbox_data.shape()[1];
|
||||
for idx in 0..num_priors {
|
||||
let dx = bbox_data[[0, idx, 0]];
|
||||
let dy = bbox_data[[0, idx, 1]];
|
||||
let dw = bbox_data[[0, idx, 2]];
|
||||
let dh = bbox_data[[0, idx, 3]];
|
||||
let (anchor_cx, anchor_cy, anchor_w, anchor_h) = anchors[idx];
|
||||
let pred_cx = anchor_cx + dx * var0 * anchor_w;
|
||||
let pred_cy = anchor_cy + dy * var0 * anchor_h;
|
||||
let pred_w = anchor_w * (dw * var1).exp();
|
||||
let pred_h = anchor_h * (dh * var1).exp();
|
||||
let x_min = pred_cx - pred_w / 2.0;
|
||||
let y_min = pred_cy - pred_h / 2.0;
|
||||
let x_max = pred_cx + pred_w / 2.0;
|
||||
let y_max = pred_cy + pred_h / 2.0;
|
||||
let score = conf_data[[0, idx, 1]];
|
||||
if score > config.threshold {
|
||||
boxes.push(Aabb2::from_x1y1x2y2(x_min, y_min, x_max, y_max));
|
||||
scores.push(score);
|
||||
|
||||
let left_eye_x = landmark_data[[0, idx, 0]] * anchor_w * var0 + anchor_cx;
|
||||
let left_eye_y = landmark_data[[0, idx, 1]] * anchor_h * var0 + anchor_cy;
|
||||
|
||||
let right_eye_x = landmark_data[[0, idx, 2]] * anchor_w * var0 + anchor_cx;
|
||||
let right_eye_y = landmark_data[[0, idx, 3]] * anchor_h * var0 + anchor_cy;
|
||||
|
||||
let nose_x = landmark_data[[0, idx, 4]] * anchor_w * var0 + anchor_cx;
|
||||
let nose_y = landmark_data[[0, idx, 5]] * anchor_h * var0 + anchor_cy;
|
||||
|
||||
let left_mouth_x = landmark_data[[0, idx, 6]] * anchor_w * var0 + anchor_cx;
|
||||
let left_mouth_y = landmark_data[[0, idx, 7]] * anchor_h * var0 + anchor_cy;
|
||||
|
||||
let right_mouth_x = landmark_data[[0, idx, 8]] * anchor_w * var0 + anchor_cx;
|
||||
let right_mouth_y = landmark_data[[0, idx, 9]] * anchor_h * var0 + anchor_cy;
|
||||
|
||||
landmarks.push(FaceLandmarks {
|
||||
left_eye: Point2::new(left_eye_x, left_eye_y),
|
||||
right_eye: Point2::new(right_eye_x, right_eye_y),
|
||||
nose: Point2::new(nose_x, nose_y),
|
||||
left_mouth: Point2::new(left_mouth_x, left_mouth_y),
|
||||
right_mouth: Point2::new(right_mouth_x, right_mouth_y),
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(FaceDetectionProcessedOutput {
|
||||
bbox: boxes,
|
||||
confidence: scores,
|
||||
landmarks,
|
||||
let feature_maps: Vec<(usize, usize)> = config
|
||||
.steps
|
||||
.iter()
|
||||
.map(|&step| {
|
||||
(
|
||||
(config.input_height as f32 / step as f32).ceil() as usize,
|
||||
(config.input_width as f32 / step as f32).ceil() as usize,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
for (k, f) in feature_maps.iter().enumerate() {
|
||||
let min_sizes = &config.min_sizes[k];
|
||||
for i in 0..f.0 {
|
||||
for j in 0..f.1 {
|
||||
for &min_size in min_sizes {
|
||||
let s_kx = min_size as f32 / config.input_width as f32;
|
||||
let s_ky = min_size as f32 / config.input_height as f32;
|
||||
let dense_cx =
|
||||
(j as f32 + 0.5) * config.steps[k] as f32 / config.input_width as f32;
|
||||
let dense_cy =
|
||||
(i as f32 + 0.5) * config.steps[k] as f32 / config.input_height as f32;
|
||||
anchors.push([
|
||||
dense_cx - s_kx / 2.,
|
||||
dense_cy - s_ky / 2.,
|
||||
dense_cx + s_kx / 2.,
|
||||
dense_cy + s_ky / 2.,
|
||||
]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ndarray::Array2::from_shape_vec((anchors.len(), 4), anchors.into_iter().flatten().collect())
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
impl FaceDetectionModelOutput {
|
||||
pub fn postprocess(self, config: &FaceDetectionConfig) -> Result<FaceDetectionProcessedOutput> {
|
||||
use ndarray::s;
|
||||
|
||||
let priors = generate_anchors(config);
|
||||
|
||||
let scores = self.confidence.slice(s![0, .., 1]);
|
||||
let boxes = self.bbox.slice(s![0, .., ..]);
|
||||
let landmarks_raw = self.landmark.slice(s![0, .., ..]);
|
||||
|
||||
||||
let (decoded_boxes, decoded_landmarks, confidences) = (0..priors.shape()[0])
|
||||
.filter(|&i| scores[i] > config.threshold)
|
||||
.map(|i| {
|
||||
let prior = priors.row(i);
|
||||
let loc = boxes.row(i);
|
||||
let landm = landmarks_raw.row(i);
|
||||
|
||||
// Decode bounding box
|
||||
let prior_cx = (prior[0] + prior[2]) / 2.0;
|
||||
let prior_cy = (prior[1] + prior[3]) / 2.0;
|
||||
let prior_w = prior[2] - prior[0];
|
||||
let prior_h = prior[3] - prior[1];
|
||||
|
||||
let var = config.variances;
|
||||
let cx = prior_cx + loc[0] * var[0] * prior_w;
|
||||
let cy = prior_cy + loc[1] * var[0] * prior_h;
|
||||
let w = prior_w * (loc[2] * var[1]).exp();
|
||||
let h = prior_h * (loc[3] * var[1]).exp();
|
||||
|
||||
let xmin = cx - w / 2.0;
|
||||
let ymin = cy - h / 2.0;
|
||||
let xmax = cx + w / 2.0;
|
||||
let ymax = cy + h / 2.0;
|
||||
|
||||
let mut bbox =
|
||||
Aabb2::from_min_max_vertices(Point2::new(xmin, ymin), Point2::new(xmax, ymax));
|
||||
if config.clamp {
|
||||
bbox = bbox.component_clamp(0.0, 1.0);
|
||||
}
|
||||
|
||||
// Decode landmarks
|
||||
let points: [Point2<f32>; 5] = (0..5)
|
||||
.map(|j| {
|
||||
Point2::new(
|
||||
prior_cx + landm[j * 2] * var[0] * prior_w,
|
||||
prior_cy + landm[j * 2 + 1] * var[0] * prior_h,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
let landmarks = FaceLandmarks {
|
||||
left_eye: points[0],
|
||||
right_eye: points[1],
|
||||
nose: points[2],
|
||||
left_mouth: points[3],
|
||||
right_mouth: points[4],
|
||||
};
|
||||
|
||||
(bbox, landmarks, scores[i])
|
||||
})
|
||||
.fold(
|
||||
(Vec::new(), Vec::new(), Vec::new()),
|
||||
|(mut boxes, mut landmarks, mut confs), (bbox, landmark, conf)| {
|
||||
boxes.push(bbox);
|
||||
landmarks.push(landmark);
|
||||
confs.push(conf);
|
||||
(boxes, landmarks, confs)
|
||||
},
|
||||
);
|
||||
Ok(FaceDetectionProcessedOutput {
|
||||
bbox: decoded_boxes,
|
||||
confidence: confidences,
|
||||
landmarks: decoded_landmarks,
|
||||
})
|
||||
}
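
    // Numeric sanity check for the decoding above (variances [0.1, 0.2]): a
    // prior centred at (0.5, 0.5) with width and height 0.1 and a regression
    // output loc = [1.0, -1.0, 0.0, 0.0] decodes to
    //   cx = 0.5 + 1.0 * 0.1 * 0.1 = 0.51,   cy = 0.5 - 0.01 = 0.49,
    //   w  = 0.1 * exp(0.0 * 0.2) = 0.1,     h  = 0.1,
    // i.e. the box shifts by a tenth of its own size and keeps its scale.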
|
||||
|
||||
pub fn print(&self, limit: usize) {
|
||||
tracing::info!("Detected {} faces", self.bbox.shape()[1]);
|
||||
|
||||
@@ -189,49 +265,16 @@ impl FaceDetectionModelOutput {
|
||||
}
|
||||
}
|
||||
|
||||
impl FaceDetection {
|
||||
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
|
||||
let model = std::fs::read(path)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to read model file")?;
|
||||
Self::new_from_bytes(&model)
|
||||
}
|
||||
|
||||
pub fn new_from_bytes(model: &[u8]) -> Result<Self> {
|
||||
tracing::info!("Loading face detection model from bytes");
|
||||
let mut model = mnn::Interpreter::from_bytes(model)
|
||||
.map_err(|e| e.into_inner())
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to load model from bytes")?;
|
||||
model.set_session_mode(mnn::SessionMode::Release);
|
||||
let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High);
|
||||
let sc = mnn::ScheduleConfig::new()
|
||||
.with_type(mnn::ForwardType::CPU)
|
||||
.with_backend_config(bc);
|
||||
tracing::info!("Creating session handle for face detection model");
|
||||
let handle = mnn_sync::SessionHandle::new(model, sc)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create session handle")?;
|
||||
Ok(FaceDetection { handle })
|
||||
}
|
||||
|
||||
pub fn detect_faces(
|
||||
&self,
|
||||
image: ndarray::Array3<u8>,
|
||||
config: FaceDetectionConfig,
|
||||
) -> Result<FaceDetectionOutput> {
|
||||
let (height, width, _channels) = image.dim();
|
||||
let output = self
|
||||
.run_models(image)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to detect faces")?;
|
||||
// denormalize the bounding boxes
|
||||
let factor = Vector2::new(width as f32, height as f32);
|
||||
let mut processed = output
|
||||
.postprocess(&config)
|
||||
.attach_printable("Failed to postprocess")?;
|
||||
|
||||
/// Apply Non-Maximum Suppression and convert to final output format
|
||||
pub fn apply_nms_and_finalize(
|
||||
processed: FaceDetectionProcessedOutput,
|
||||
config: &FaceDetectionConfig,
|
||||
image_size: (usize, usize), // (width, height)
|
||||
) -> Result<FaceDetectionOutput> {
|
||||
use itertools::Itertools;
|
||||
|
||||
let factor = Vector2::new(image_size.0 as f32, image_size.1 as f32);
|
||||
|
||||
let (boxes, scores, landmarks): (Vec<_>, Vec<_>, Vec<_>) = processed
|
||||
.bbox
|
||||
.iter()
|
||||
@@ -242,7 +285,8 @@ impl FaceDetection {
|
||||
.map(|((b, s), l)| (b, s, l))
|
||||
.multiunzip();
|
||||
|
||||
let keep_indices = nms(&boxes, &scores, config.threshold, config.nms_threshold);
|
||||
let keep_indices =
|
||||
nms(&boxes, &scores, config.threshold, config.nms_threshold).change_context(Error)?;
|
||||
|
||||
let bboxes = boxes
|
||||
.into_iter()
|
||||
@@ -268,81 +312,29 @@ impl FaceDetection {
|
||||
confidence,
|
||||
landmark,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run_models(&self, image: ndarray::Array3<u8>) -> Result<FaceDetectionModelOutput> {
|
||||
#[rustfmt::skip]
|
||||
use ::tap::*;
|
||||
/// Common trait for face detection backends
|
||||
pub trait FaceDetector {
|
||||
/// Run inference on the model and return raw outputs
|
||||
fn run_model(&mut self, image: ndarray::ArrayView3<u8>) -> Result<FaceDetectionModelOutput>;
|
||||
|
||||
/// Detect faces with full pipeline including postprocessing
|
||||
fn detect_faces(
|
||||
&mut self,
|
||||
image: ndarray::ArrayView3<u8>,
|
||||
config: &FaceDetectionConfig,
|
||||
) -> Result<FaceDetectionOutput> {
|
||||
let (height, width, _channels) = image.dim();
|
||||
let output = self
|
||||
.handle
|
||||
.run(move |sr| {
|
||||
let mut resized = image
|
||||
.fast_resize(1024, 1024, None)
|
||||
.change_context(mnn::ErrorKind::TensorError)?
|
||||
.mapv(|f| f as f32)
|
||||
.tap_mut(|arr| {
|
||||
arr.axis_iter_mut(ndarray::Axis(2))
|
||||
.zip([104, 117, 123])
|
||||
.for_each(|(mut array, pixel)| {
|
||||
let pixel = pixel as f32;
|
||||
array.map_inplace(|v| *v -= pixel);
|
||||
});
|
||||
})
|
||||
.permuted_axes((2, 0, 1))
|
||||
.insert_axis(ndarray::Axis(0))
|
||||
.as_standard_layout()
|
||||
.into_owned();
|
||||
let tensor = resized
|
||||
.as_mnn_tensor_mut()
|
||||
.attach_printable("Failed to convert ndarray to mnn tensor")
|
||||
.change_context(mnn::error::ErrorKind::TensorError)?;
|
||||
tracing::trace!("Image Tensor shape: {:?}", tensor.shape());
|
||||
let (intptr, session) = sr.both_mut();
|
||||
tracing::trace!("Copying input tensor to host");
|
||||
unsafe {
|
||||
let mut input = intptr.input_unresized::<f32>(session, "input")?;
|
||||
tracing::trace!("Input shape: {:?}", input.shape());
|
||||
intptr.resize_tensor_by_nchw::<mnn::View<&mut f32>, _>(
|
||||
input.view_mut(),
|
||||
1,
|
||||
3,
|
||||
1024,
|
||||
1024,
|
||||
);
|
||||
}
|
||||
intptr.resize_session(session);
|
||||
let mut input = intptr.input::<f32>(session, "input")?;
|
||||
tracing::trace!("Input shape: {:?}", input.shape());
|
||||
input.copy_from_host_tensor(tensor.view())?;
|
||||
.run_model(image)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to detect faces")?;
|
||||
|
||||
tracing::info!("Running face detection session");
|
||||
intptr.run_session(&session)?;
|
||||
let output_tensor = intptr
|
||||
.output::<f32>(&session, "bbox")?
|
||||
.create_host_tensor_from_device(true)
|
||||
.as_ndarray()
|
||||
.to_owned();
|
||||
tracing::trace!("Output Bbox: \t\t{:?}", output_tensor.shape());
|
||||
let output_confidence = intptr
|
||||
.output::<f32>(&session, "confidence")?
|
||||
.create_host_tensor_from_device(true)
|
||||
.as_ndarray::<ndarray::Ix3>()
|
||||
.to_owned();
|
||||
tracing::trace!("Output Confidence: \t{:?}", output_confidence.shape());
|
||||
let output_landmark = intptr
|
||||
.output::<f32>(&session, "landmark")?
|
||||
.create_host_tensor_from_device(true)
|
||||
.as_ndarray::<ndarray::Ix3>()
|
||||
.to_owned();
|
||||
tracing::trace!("Output Landmark: \t{:?}", output_landmark.shape());
|
||||
Ok(FaceDetectionModelOutput {
|
||||
bbox: output_tensor,
|
||||
confidence: output_confidence,
|
||||
landmark: output_landmark,
|
||||
})
|
||||
})
|
||||
.map_err(|e| e.into_inner())
|
||||
.change_context(Error)?;
|
||||
Ok(output)
|
||||
let processed = output
|
||||
.postprocess(&config)
|
||||
.attach_printable("Failed to postprocess")?;
|
||||
|
||||
apply_nms_and_finalize(processed, &config, (width, height))
|
||||
}
|
||||
}
|
||||
|
||||
146
src/facedet/retinaface/mnn.rs
Normal file
@@ -0,0 +1,146 @@
|
||||
use crate::errors::*;
|
||||
use crate::facedet::*;
|
||||
use error_stack::ResultExt;
|
||||
use mnn_bridge::ndarray::*;
|
||||
use ndarray_resize::NdFir;
|
||||
use std::path::Path;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FaceDetection {
|
||||
handle: mnn_sync::SessionHandle,
|
||||
}
|
||||
|
||||
pub struct FaceDetectionBuilder {
|
||||
schedule_config: Option<mnn::ScheduleConfig>,
|
||||
backend_config: Option<mnn::BackendConfig>,
|
||||
model: mnn::Interpreter,
|
||||
}
|
||||
|
||||
impl FaceDetectionBuilder {
|
||||
pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
|
||||
Ok(Self {
|
||||
schedule_config: None,
|
||||
backend_config: None,
|
||||
model: mnn::Interpreter::from_bytes(model.as_ref())
|
||||
.map_err(|e| e.into_inner())
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to load model from bytes")?,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
|
||||
self.schedule_config
|
||||
.get_or_insert_default()
|
||||
.set_type(forward_type);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
|
||||
self.schedule_config = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
|
||||
self.backend_config = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Result<FaceDetection> {
|
||||
let model = self.model;
|
||||
let sc = self.schedule_config.unwrap_or_default();
|
||||
let handle = mnn_sync::SessionHandle::new(model, sc)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create session handle")?;
|
||||
Ok(FaceDetection { handle })
|
||||
}
|
||||
}
|
||||
|
||||
impl FaceDetection {
|
||||
pub fn builder<T: AsRef<[u8]>>(
|
||||
model: T,
|
||||
) -> std::result::Result<FaceDetectionBuilder, error_stack::Report<crate::errors::Error>> {
|
||||
FaceDetectionBuilder::new(model)
|
||||
}
|
||||
}
|
||||
|
||||
impl FaceDetector for FaceDetection {
|
||||
fn run_model(&mut self, image: ndarray::ArrayView3<u8>) -> Result<FaceDetectionModelOutput> {
|
||||
#[rustfmt::skip]
|
||||
let mut resized = image
|
||||
.fast_resize(1024, 1024, None)
|
||||
.change_context(Error)?
|
||||
.mapv(|f| f as f32);
|
||||
|
||||
// Apply mean subtraction: [104, 117, 123]
|
||||
resized
|
||||
.axis_iter_mut(ndarray::Axis(2))
|
||||
.zip([104, 117, 123])
|
||||
.for_each(|(mut array, pixel)| {
|
||||
let pixel = pixel as f32;
|
||||
array.map_inplace(|v| *v -= pixel);
|
||||
});
|
||||
|
||||
let mut resized = resized
|
||||
.permuted_axes((2, 0, 1))
|
||||
.insert_axis(ndarray::Axis(0))
|
||||
.as_standard_layout()
|
||||
.into_owned();
|
||||
|
||||
use ::tap::*;
|
||||
let output = self
|
||||
.handle
|
||||
.run(move |sr| {
|
||||
let tensor = resized
|
||||
.as_mnn_tensor_mut()
|
||||
.attach_printable("Failed to convert ndarray to mnn tensor")
|
||||
.change_context(mnn::error::ErrorKind::TensorError)?;
|
||||
tracing::trace!("Image Tensor shape: {:?}", tensor.shape());
|
||||
let (intptr, session) = sr.both_mut();
|
||||
tracing::trace!("Copying input tensor to host");
|
||||
unsafe {
|
||||
let mut input = intptr.input_unresized::<f32>(session, "input")?;
|
||||
tracing::trace!("Input shape: {:?}", input.shape());
|
||||
intptr.resize_tensor_by_nchw::<mnn::View<&mut f32>, _>(
|
||||
input.view_mut(),
|
||||
1,
|
||||
3,
|
||||
1024,
|
||||
1024,
|
||||
);
|
||||
}
|
||||
intptr.resize_session(session);
|
||||
let mut input = intptr.input::<f32>(session, "input")?;
|
||||
tracing::trace!("Input shape: {:?}", input.shape());
|
||||
input.copy_from_host_tensor(tensor.view())?;
|
||||
|
||||
tracing::info!("Running face detection session");
|
||||
intptr.run_session(&session)?;
|
||||
let output_tensor = intptr
|
||||
.output::<f32>(&session, "bbox")?
|
||||
.create_host_tensor_from_device(true)
|
||||
.as_ndarray()
|
||||
.to_owned();
|
||||
tracing::trace!("Output Bbox: \t\t{:?}", output_tensor.shape());
|
||||
let output_confidence = intptr
|
||||
.output::<f32>(&session, "confidence")?
|
||||
.create_host_tensor_from_device(true)
|
||||
.as_ndarray::<ndarray::Ix3>()
|
||||
.to_owned();
|
||||
tracing::trace!("Output Confidence: \t{:?}", output_confidence.shape());
|
||||
let output_landmark = intptr
|
||||
.output::<f32>(&session, "landmark")?
|
||||
.create_host_tensor_from_device(true)
|
||||
.as_ndarray::<ndarray::Ix3>()
|
||||
.to_owned();
|
||||
tracing::trace!("Output Landmark: \t{:?}", output_landmark.shape());
|
||||
Ok(FaceDetectionModelOutput {
|
||||
bbox: output_tensor,
|
||||
confidence: output_confidence,
|
||||
landmark: output_landmark,
|
||||
})
|
||||
})
|
||||
.map_err(|e| e.into_inner())
|
||||
.change_context(Error)?;
|
||||
Ok(output)
|
||||
}
|
||||
}
|
||||
256
src/facedet/retinaface/ort.rs
Normal file
@@ -0,0 +1,256 @@
|
||||
use crate::errors::*;
|
||||
use crate::facedet::*;
|
||||
use crate::ort_ep::*;
|
||||
use error_stack::ResultExt;
|
||||
use ndarray_resize::NdFir;
|
||||
use ort::{
|
||||
execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch},
|
||||
session::{Session, builder::GraphOptimizationLevel},
|
||||
value::Tensor,
|
||||
};
|
||||
use std::path::Path;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FaceDetection {
|
||||
session: Session,
|
||||
}
|
||||
|
||||
pub struct FaceDetectionBuilder {
|
||||
model_data: Vec<u8>,
|
||||
execution_providers: Option<Vec<ExecutionProviderDispatch>>,
|
||||
intra_threads: Option<usize>,
|
||||
inter_threads: Option<usize>,
|
||||
}
|
||||
|
||||
impl FaceDetectionBuilder {
|
||||
pub fn new(model: impl AsRef<[u8]>) -> crate::errors::Result<Self> {
|
||||
Ok(Self {
|
||||
model_data: model.as_ref().to_vec(),
|
||||
execution_providers: None,
|
||||
intra_threads: None,
|
||||
inter_threads: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn with_execution_providers(mut self, providers: impl AsRef<[ExecutionProvider]>) -> Self {
|
||||
let execution_providers: Vec<ExecutionProviderDispatch> = providers
|
||||
.as_ref()
|
||||
.iter()
|
||||
.filter_map(|provider| provider.to_dispatch())
|
||||
.collect();
|
||||
|
||||
if !execution_providers.is_empty() {
|
||||
self.execution_providers = Some(execution_providers);
|
||||
} else {
|
||||
tracing::warn!("No valid execution providers found, falling back to CPU");
|
||||
self.execution_providers = Some(vec![CPUExecutionProvider::default().build()]);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_intra_threads(mut self, threads: usize) -> Self {
|
||||
self.intra_threads = Some(threads);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_inter_threads(mut self, threads: usize) -> Self {
|
||||
self.inter_threads = Some(threads);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> crate::errors::Result<FaceDetection> {
|
||||
let mut session_builder = Session::builder()
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create session builder")?;
|
||||
|
||||
// Set execution providers
|
||||
if let Some(providers) = self.execution_providers {
|
||||
session_builder = session_builder
|
||||
.with_execution_providers(providers)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to set execution providers")?;
|
||||
} else {
|
||||
// Default to CPU
|
||||
session_builder = session_builder
|
||||
.with_execution_providers([CPUExecutionProvider::default().build()])
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to set default CPU execution provider")?;
|
||||
}
|
||||
|
||||
// Set threading options
|
||||
if let Some(threads) = self.intra_threads {
|
||||
session_builder = session_builder
|
||||
.with_intra_threads(threads)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to set intra threads")?;
|
||||
}
|
||||
|
||||
if let Some(threads) = self.inter_threads {
|
||||
session_builder = session_builder
|
||||
.with_inter_threads(threads)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to set inter threads")?;
|
||||
}
|
||||
|
||||
// Set optimization level
|
||||
session_builder = session_builder
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to set optimization level")?;
|
||||
|
||||
// Create session from model bytes
|
||||
let session = session_builder
|
||||
.commit_from_memory(&self.model_data)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create ORT session from model bytes")?;
|
||||
|
||||
tracing::info!("Successfully created ORT RetinaFace session");
|
||||
|
||||
Ok(FaceDetection { session })
|
||||
}
|
||||
}
|
||||
|
||||
impl FaceDetection {
|
||||
pub fn builder<T: AsRef<[u8]>>(
|
||||
model: T,
|
||||
) -> std::result::Result<FaceDetectionBuilder, error_stack::Report<crate::errors::Error>> {
|
||||
FaceDetectionBuilder::new(model)
|
||||
}
|
||||
|
||||
pub fn new(path: impl AsRef<Path>) -> crate::errors::Result<Self> {
|
||||
let model = std::fs::read(path)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to read model file")?;
|
||||
Self::new_from_bytes(&model)
|
||||
}
|
||||
|
||||
pub fn new_from_bytes(model: &[u8]) -> crate::errors::Result<Self> {
|
||||
tracing::info!("Loading ORT RetinaFace model from bytes");
|
||||
Self::builder(model)?.build()
|
||||
}
|
||||
}
|
||||
|
||||
impl FaceDetector for FaceDetection {
|
||||
fn run_model(
|
||||
&mut self,
|
||||
image: ndarray::ArrayView3<u8>,
|
||||
) -> crate::errors::Result<FaceDetectionModelOutput> {
|
||||
// Resize image to 1024x1024
|
||||
let mut resized = image
|
||||
.fast_resize(1024, 1024, None)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to resize image")?
|
||||
.mapv(|f| f as f32);
|
||||
|
||||
// Apply mean subtraction: [104, 117, 123] for BGR format
|
||||
resized
|
||||
.axis_iter_mut(ndarray::Axis(2))
|
||||
.zip([104.0, 117.0, 123.0])
|
||||
.for_each(|(mut array, mean)| {
|
||||
array.map_inplace(|v| *v -= mean);
|
||||
});
|
||||
|
||||
// Convert from HWC to NCHW format (add batch dimension and transpose)
|
||||
let input_tensor = resized
|
||||
.permuted_axes((2, 0, 1))
|
||||
.insert_axis(ndarray::Axis(0))
|
||||
.as_standard_layout()
|
||||
.into_owned();
|
||||
|
||||
tracing::trace!("Input tensor shape: {:?}", input_tensor.shape());
|
||||
|
||||
// Create ORT input tensor
|
||||
let input_value = Tensor::from_array(input_tensor)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create input tensor")?;
|
||||
|
||||
// Run inference
|
||||
tracing::debug!("Running ORT RetinaFace inference");
|
||||
let outputs = self
|
||||
.session
|
||||
.run(ort::inputs!["input" => input_value])
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to run inference")?;
|
||||
|
||||
// Extract outputs by name
|
||||
let bbox_output = outputs
|
||||
.get("bbox")
|
||||
.ok_or(Error)
|
||||
.attach_printable("Missing bbox output from model")?
|
||||
.try_extract_tensor::<f32>()
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to extract bbox tensor")?;
|
||||
|
||||
let confidence_output = outputs
|
||||
.get("confidence")
|
||||
.ok_or(Error)
|
||||
.attach_printable("Missing confidence output from model")?
|
||||
.try_extract_tensor::<f32>()
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to extract confidence tensor")?;
|
||||
|
||||
let landmark_output = outputs
|
||||
.get("landmark")
|
||||
.ok_or(Error)
|
||||
.attach_printable("Missing landmark output from model")?
|
||||
.try_extract_tensor::<f32>()
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to extract landmark tensor")?;
|
||||
|
||||
// Get tensor shapes and data
|
||||
let (bbox_shape, bbox_data) = bbox_output;
|
||||
let (confidence_shape, confidence_data) = confidence_output;
|
||||
let (landmark_shape, landmark_data) = landmark_output;
|
||||
|
||||
tracing::trace!(
|
||||
"Output shapes - bbox: {:?}, confidence: {:?}, landmark: {:?}",
|
||||
bbox_shape,
|
||||
confidence_shape,
|
||||
landmark_shape
|
||||
);
|
||||
|
||||
// Convert to ndarray format
|
||||
let bbox_dims = bbox_shape.as_ref();
|
||||
let confidence_dims = confidence_shape.as_ref();
|
||||
let landmark_dims = landmark_shape.as_ref();
|
||||
|
||||
let bbox_array = ndarray::Array3::from_shape_vec(
|
||||
(
|
||||
bbox_dims[0] as usize,
|
||||
bbox_dims[1] as usize,
|
||||
bbox_dims[2] as usize,
|
||||
),
|
||||
bbox_data.to_vec(),
|
||||
)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create bbox ndarray")?;
|
||||
|
||||
let confidence_array = ndarray::Array3::from_shape_vec(
|
||||
(
|
||||
confidence_dims[0] as usize,
|
||||
confidence_dims[1] as usize,
|
||||
confidence_dims[2] as usize,
|
||||
),
|
||||
confidence_data.to_vec(),
|
||||
)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create confidence ndarray")?;
|
||||
|
||||
let landmark_array = ndarray::Array3::from_shape_vec(
|
||||
(
|
||||
landmark_dims[0] as usize,
|
||||
landmark_dims[1] as usize,
|
||||
landmark_dims[2] as usize,
|
||||
),
|
||||
landmark_data.to_vec(),
|
||||
)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to create landmark ndarray")?;
|
||||
|
||||
Ok(FaceDetectionModelOutput {
|
||||
bbox: bbox_array,
|
||||
confidence: confidence_array,
|
||||
landmark: landmark_array,
|
||||
})
|
||||
}
|
||||
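// Construction sketch for this ORT backend; the model path and thread count are
// placeholders, and `detect_faces` comes from the `FaceDetector` trait:
//
//     let bytes = std::fs::read("path/to/retinaface.onnx")?;
//     let mut detector = FaceDetection::builder(&bytes)?
//         .with_intra_threads(4)
//         .build()?;
//     let output = detector.detect_faces(image.view(), &FaceDetectionConfig::default())?;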
}
|
||||
35
src/faceembed.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
pub mod facenet;
|
||||
|
||||
// Re-export common types and traits
|
||||
pub use facenet::FaceNetEmbedder;
|
||||
pub use facenet::{FaceEmbedding, FaceEmbeddingConfig, IntoEmbeddings};
|
||||
|
||||
// Convenience type aliases for different backends
|
||||
pub use facenet::mnn::EmbeddingGenerator as MnnEmbeddingGenerator;
|
||||
pub use facenet::ort::EmbeddingGenerator as OrtEmbeddingGenerator;
|
||||
|
||||
use crate::errors::*;
|
||||
use ndarray::{Array2, ArrayView4};
|
||||
|
||||
pub mod preprocessing {
|
||||
use ndarray::*;
|
||||
pub fn preprocess(faces: ArrayView4<u8>) -> Array4<f32> {
|
||||
let mut owned = faces.as_standard_layout().mapv(|v| v as f32).to_owned();
|
||||
owned.axis_iter_mut(Axis(0)).for_each(|mut image| {
|
||||
let mean = image.mean().unwrap_or(0.0);
|
||||
let std = image.std(0.0);
|
||||
if std > 0.0 {
|
||||
image.mapv_inplace(|x| (x - mean) / std);
|
||||
} else {
|
||||
image.mapv_inplace(|x| (x - 127.5) / 128.0)
|
||||
}
|
||||
});
|
||||
owned
|
||||
}
|
||||
}
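
// The preprocessing above is per-image standardisation (FaceNet-style
// prewhitening), with a fixed (x - 127.5) / 128 fallback for constant images.
// A minimal call-site sketch; the 320x320 size follows the default in
// `FaceEmbeddingConfig` and is only an assumption here:
//
//     let faces: ndarray::Array4<u8> = ndarray::Array4::zeros((1, 320, 320, 3));
//     let batch = preprocessing::preprocess(faces.view());
//     assert_eq!(batch.dim(), (1, 320, 320, 3));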
|
||||
|
||||
/// Common trait for face embedding backends - maintained for backward compatibility
|
||||
pub trait FaceEmbedder {
|
||||
/// Generate embeddings for a batch of face images
|
||||
fn run_models(&mut self, faces: ArrayView4<u8>) -> Result<Array2<f32>>;
|
||||
}
|
||||
209
src/faceembed/facenet.rs
Normal file
@@ -0,0 +1,209 @@
|
||||
pub mod mnn;
|
||||
pub mod ort;
|
||||
|
||||
use crate::errors::*;
|
||||
use error_stack::ResultExt;
|
||||
use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
|
||||
use ndarray_math::{CosineSimilarity, EuclideanDistance};
|
||||
|
||||
/// Configuration for face embedding processing
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct FaceEmbeddingConfig {
|
||||
/// Input image width expected by the model
|
||||
pub input_width: usize,
|
||||
/// Input image height expected by the model
|
||||
pub input_height: usize,
|
||||
/// Whether to normalize embeddings to unit vectors
|
||||
pub normalize: bool,
|
||||
}
|
||||
|
||||
impl FaceEmbeddingConfig {
|
||||
pub fn with_input_size(mut self, width: usize, height: usize) -> Self {
|
||||
self.input_width = width;
|
||||
self.input_height = height;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_normalization(mut self, normalize: bool) -> Self {
|
||||
self.normalize = normalize;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for FaceEmbeddingConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
input_width: 320,
|
||||
input_height: 320,
|
||||
normalize: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a face embedding vector
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct FaceEmbedding {
|
||||
/// The embedding vector
|
||||
pub vector: Array1<f32>,
|
||||
/// Optional confidence score for the embedding quality
|
||||
pub confidence: Option<f32>,
|
||||
}
|
||||
|
||||
impl FaceEmbedding {
|
||||
pub fn new(vector: Array1<f32>) -> Self {
|
||||
Self {
|
||||
vector,
|
||||
confidence: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_confidence(mut self, confidence: f32) -> Self {
|
||||
self.confidence = Some(confidence);
|
||||
self
|
||||
}
|
||||
|
||||
/// Calculate cosine similarity with another embedding
|
||||
pub fn cosine_similarity(&self, other: &FaceEmbedding) -> f32 {
|
||||
self.vector.cosine_similarity(&other.vector).unwrap_or(0.0)
|
||||
}
|
||||
|
||||
/// Calculate Euclidean distance with another embedding
|
||||
pub fn euclidean_distance(&self, other: &FaceEmbedding) -> f32 {
|
||||
self.vector
|
||||
.euclidean_distance(other.vector.view())
|
||||
.unwrap_or(f32::INFINITY)
|
||||
}
|
||||
|
||||
/// Normalize the embedding vector to unit length
|
||||
pub fn normalize(&mut self) {
|
||||
let norm = self.vector.mapv(|x| x * x).sum().sqrt();
|
||||
if norm > 0.0 {
|
||||
self.vector.mapv_inplace(|x| x / norm);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the dimensionality of the embedding
|
||||
pub fn dimension(&self) -> usize {
|
||||
self.vector.len()
|
||||
}
|
||||
}
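
// For unit-length embeddings the two metrics above are tied by
// ||a - b||^2 = 2 - 2 * cos(a, b); e.g. [1, 0] and [0, 1] give a cosine
// similarity of 0 and a Euclidean distance of sqrt(2).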
|
||||
|
||||
/// Raw model outputs that can be converted to embeddings
|
||||
pub trait IntoEmbeddings {
|
||||
fn into_embeddings(self, config: &FaceEmbeddingConfig) -> Result<Vec<FaceEmbedding>>;
|
||||
}
|
||||
|
||||
impl IntoEmbeddings for Array2<f32> {
|
||||
fn into_embeddings(self, config: &FaceEmbeddingConfig) -> Result<Vec<FaceEmbedding>> {
|
||||
let mut embeddings = Vec::new();
|
||||
|
||||
for row in self.rows() {
|
||||
let mut vector = row.to_owned();
|
||||
|
||||
if config.normalize {
|
||||
let norm = vector.mapv(|x| x * x).sum().sqrt();
|
||||
if norm > 0.0 {
|
||||
vector.mapv_inplace(|x| x / norm);
|
||||
}
|
||||
}
|
||||
|
||||
embeddings.push(FaceEmbedding::new(vector));
|
||||
}
|
||||
|
||||
Ok(embeddings)
|
||||
}
|
||||
}
|
||||
|
||||
/// Common trait for face embedding backends
|
||||
pub trait FaceNetEmbedder {
|
||||
/// Generate embeddings for a batch of face images
|
||||
fn run_model(&mut self, faces: ArrayView4<u8>) -> Result<Array2<f32>>;
|
||||
|
||||
/// Generate embeddings with full pipeline including postprocessing
|
||||
fn generate_embeddings(
|
||||
&mut self,
|
||||
faces: ArrayView4<u8>,
|
||||
config: FaceEmbeddingConfig,
|
||||
) -> Result<Vec<FaceEmbedding>> {
|
||||
let raw_output = self
|
||||
.run_model(faces)
|
||||
.change_context(Error)
|
||||
.attach_printable("Failed to generate embeddings")?;
|
||||
|
||||
raw_output
|
||||
.into_embeddings(&config)
|
||||
.attach_printable("Failed to process embeddings")
|
||||
}
|
||||
|
||||
/// Generate a single embedding from a single face image
|
||||
fn generate_embedding(
|
||||
&mut self,
|
||||
face: ArrayView3<u8>,
|
||||
config: FaceEmbeddingConfig,
|
||||
) -> Result<FaceEmbedding> {
|
||||
// Add batch dimension
|
||||
let face_batch = face.insert_axis(ndarray::Axis(0));
|
||||
let embeddings = self.generate_embeddings(face_batch.view(), config)?;
|
||||
|
||||
embeddings
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or(Error)
|
||||
.attach_printable("No embedding generated for input face")
|
||||
}
|
||||
}
|
||||
|
||||
/// Utility functions for embedding processing
|
||||
pub mod utils {
|
||||
use super::*;
|
||||
|
||||
/// Compute pairwise cosine similarities between two sets of embeddings
|
||||
pub fn pairwise_cosine_similarities(
|
||||
embeddings1: &[FaceEmbedding],
|
||||
embeddings2: &[FaceEmbedding],
|
||||
) -> Array2<f32> {
|
||||
let n1 = embeddings1.len();
|
||||
let n2 = embeddings2.len();
|
||||
let mut similarities = Array2::zeros((n1, n2));
|
||||
|
||||
for (i, emb1) in embeddings1.iter().enumerate() {
|
||||
for (j, emb2) in embeddings2.iter().enumerate() {
|
||||
similarities[(i, j)] = emb1.cosine_similarity(emb2);
|
||||
}
|
||||
}
|
||||
|
||||
similarities
|
||||
}
|
||||
|
||||
/// Find the best matching embedding from a gallery for each query
|
||||
pub fn find_best_matches(
|
||||
queries: &[FaceEmbedding],
|
||||
gallery: &[FaceEmbedding],
|
||||
) -> Vec<(usize, f32)> {
|
||||
let similarities = pairwise_cosine_similarities(queries, gallery);
|
||||
let mut best_matches = Vec::new();
|
||||
|
||||
for i in 0..queries.len() {
|
||||
let row = similarities.row(i);
|
||||
let (best_idx, best_score) = row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
|
||||
.unwrap();
|
||||
best_matches.push((best_idx, *best_score));
|
||||
}
|
||||
|
||||
best_matches
|
||||
}
|
||||
|
||||
/// Filter embeddings by minimum quality threshold
|
||||
pub fn filter_by_confidence(
|
||||
embeddings: Vec<FaceEmbedding>,
|
||||
min_confidence: f32,
|
||||
) -> Vec<FaceEmbedding> {
|
||||
embeddings
|
||||
.into_iter()
|
||||
.filter(|emb| emb.confidence.map_or(true, |conf| conf >= min_confidence))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
127
src/faceembed/facenet/mnn.rs
Normal file
@@ -0,0 +1,127 @@
|
||||
use crate::errors::*;
use crate::faceembed::facenet::FaceNetEmbedder;
use mnn_bridge::ndarray::*;
use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
use std::path::Path;

#[derive(Debug)]
pub struct EmbeddingGenerator {
    handle: mnn_sync::SessionHandle,
}
pub struct EmbeddingGeneratorBuilder {
    schedule_config: Option<mnn::ScheduleConfig>,
    backend_config: Option<mnn::BackendConfig>,
    model: mnn::Interpreter,
}

impl EmbeddingGeneratorBuilder {
    pub fn new(model: impl AsRef<[u8]>) -> Result<Self> {
        Ok(Self {
            schedule_config: None,
            backend_config: None,
            model: mnn::Interpreter::from_bytes(model.as_ref())
                .map_err(|e| e.into_inner())
                .change_context(Error)
                .attach_printable("Failed to load model from bytes")?,
        })
    }

    pub fn with_forward_type(mut self, forward_type: mnn::ForwardType) -> Self {
        self.schedule_config
            .get_or_insert_default()
            .set_type(forward_type);
        self
    }

    pub fn with_schedule_config(mut self, config: mnn::ScheduleConfig) -> Self {
        self.schedule_config = Some(config);
        self
    }

    pub fn with_backend_config(mut self, config: mnn::BackendConfig) -> Self {
        self.backend_config = Some(config);
        self
    }

    pub fn build(self) -> Result<EmbeddingGenerator> {
        let model = self.model;
        let sc = self.schedule_config.unwrap_or_default();
        let handle = mnn_sync::SessionHandle::new(model, sc)
            .change_context(Error)
            .attach_printable("Failed to create session handle")?;
        Ok(EmbeddingGenerator { handle })
    }
}

impl EmbeddingGenerator {
    const INPUT_NAME: &'static str = "serving_default_input_6:0";
    const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0";

    pub fn builder<T: AsRef<[u8]>>(
        model: T,
    ) -> std::result::Result<EmbeddingGeneratorBuilder, Report<Error>> {
        EmbeddingGeneratorBuilder::new(model)
    }

    pub fn run_models(&self, face: ArrayView4<u8>) -> Result<Array2<f32>> {
        let tensor = crate::faceembed::preprocessing::preprocess(face);
        let shape: [usize; 4] = tensor.dim().into();
        let shape = shape.map(|f| f as i32);
        let output = self
            .handle
            .run(move |sr| {
                let tensor = tensor
                    .as_mnn_tensor()
                    .attach_printable("Failed to convert ndarray to mnn tensor")
                    .change_context(mnn::ErrorKind::TensorError)?;
                tracing::trace!("Image Tensor shape: {:?}", tensor.shape());
                let (intptr, session) = sr.both_mut();
                tracing::trace!("Copying input tensor to host");
                let needs_resize = unsafe {
                    let mut input = intptr.input_unresized::<f32>(session, Self::INPUT_NAME)?;
                    tracing::trace!("Input shape: {:?}", input.shape());
                    if *input.shape() != shape {
                        tracing::trace!("Resizing input tensor to shape: {:?}", shape);
                        // input.resize(shape);
                        intptr.resize_tensor(input.view_mut(), shape);
                        true
                    } else {
                        false
                    }
                };
                if needs_resize {
                    tracing::trace!("Resized input tensor to shape: {:?}", shape);
                    let now = std::time::Instant::now();
                    intptr.resize_session(session);
                    tracing::trace!("Session resized in {:?}", now.elapsed());
                }
                let mut input = intptr.input::<f32>(session, Self::INPUT_NAME)?;
                tracing::trace!("Input shape: {:?}", input.shape());
                input.copy_from_host_tensor(tensor.view())?;

                tracing::info!("Running face embedding session");
                intptr.run_session(&session)?;
                let output_tensor = intptr
                    .output::<f32>(&session, Self::OUTPUT_NAME)?
                    .create_host_tensor_from_device(true)
                    .as_ndarray()
                    .to_owned();
                Ok(output_tensor)
            })
            .change_context(Error)?;
        Ok(output)
    }
}

impl FaceNetEmbedder for EmbeddingGenerator {
    fn run_model(&mut self, faces: ArrayView4<u8>) -> Result<Array2<f32>> {
        self.run_models(faces)
    }
}

// Main trait implementation for backward compatibility
impl crate::faceembed::FaceEmbedder for EmbeddingGenerator {
    fn run_models(&mut self, faces: ArrayView4<u8>) -> Result<Array2<f32>> {
        EmbeddingGenerator::run_models(self, faces)
    }
}
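A minimal sketch (not part of the diff) of how the MNN builder above might be driven from caller code. The FACENET_MODEL_MNN constant and the 224x224 NHWC u8 batch shape are assumptions carried over from src/gui/bridge.rs later in this diff, not guarantees of this file.

// Sketch only: build the embedder with the CPU forward type, then run a batch.
fn embed_faces(faces: ndarray::ArrayView4<u8>) -> Result<ndarray::Array2<f32>> {
    let embedder = EmbeddingGenerator::builder(FACENET_MODEL_MNN)?
        .with_forward_type(mnn::ForwardType::CPU)
        .build()?;
    // `faces` is expected as an NHWC batch, e.g. (batch, 224, 224, 3).
    embedder.run_models(faces)
}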
207
src/faceembed/facenet/ort.rs
Normal file
207
src/faceembed/facenet/ort.rs
Normal file
@@ -0,0 +1,207 @@
|
||||
use crate::errors::*;
use crate::faceembed::facenet::FaceNetEmbedder;
use crate::ort_ep::*;
use error_stack::ResultExt;
use ndarray::{Array2, ArrayView4};
use ort::{
    execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch},
    session::{Session, builder::GraphOptimizationLevel},
    value::Tensor,
};
use std::path::Path;

#[derive(Debug)]
pub struct EmbeddingGenerator {
    session: Session,
}

pub struct EmbeddingGeneratorBuilder {
    model_data: Vec<u8>,
    execution_providers: Option<Vec<ExecutionProviderDispatch>>,
    intra_threads: Option<usize>,
    inter_threads: Option<usize>,
}

impl EmbeddingGeneratorBuilder {
    pub fn new(model: impl AsRef<[u8]>) -> crate::errors::Result<Self> {
        Ok(Self {
            model_data: model.as_ref().to_vec(),
            execution_providers: None,
            intra_threads: None,
            inter_threads: None,
        })
    }

    pub fn with_execution_providers(mut self, providers: impl AsRef<[ExecutionProvider]>) -> Self {
        let execution_providers: Vec<ExecutionProviderDispatch> = providers
            .as_ref()
            .iter()
            .filter_map(|provider| provider.to_dispatch())
            .collect();

        if !execution_providers.is_empty() {
            self.execution_providers = Some(execution_providers);
        } else {
            tracing::warn!("No valid execution providers found, falling back to CPU");
            self.execution_providers = Some(vec![CPUExecutionProvider::default().build()]);
        }
        self
    }

    pub fn with_intra_threads(mut self, threads: usize) -> Self {
        self.intra_threads = Some(threads);
        self
    }

    pub fn with_inter_threads(mut self, threads: usize) -> Self {
        self.inter_threads = Some(threads);
        self
    }

    pub fn build(self) -> crate::errors::Result<EmbeddingGenerator> {
        let mut session_builder = Session::builder()
            .change_context(Error)
            .attach_printable("Failed to create session builder")?;

        // Set execution providers
        if let Some(providers) = self.execution_providers {
            session_builder = session_builder
                .with_execution_providers(providers)
                .change_context(Error)
                .attach_printable("Failed to set execution providers")?;
        } else {
            // Default to CPU
            session_builder = session_builder
                .with_execution_providers([CPUExecutionProvider::default().build()])
                .change_context(Error)
                .attach_printable("Failed to set default CPU execution provider")?;
        }

        // Set threading options
        if let Some(threads) = self.intra_threads {
            session_builder = session_builder
                .with_intra_threads(threads)
                .change_context(Error)
                .attach_printable("Failed to set intra threads")?;
        }

        if let Some(threads) = self.inter_threads {
            session_builder = session_builder
                .with_inter_threads(threads)
                .change_context(Error)
                .attach_printable("Failed to set inter threads")?;
        }

        // Set optimization level
        session_builder = session_builder
            .with_optimization_level(GraphOptimizationLevel::Level3)
            .change_context(Error)
            .attach_printable("Failed to set optimization level")?;

        // Create session from model bytes
        let session = session_builder
            .commit_from_memory(&self.model_data)
            .change_context(Error)
            .attach_printable("Failed to create ORT session from model bytes")?;

        tracing::info!("Successfully created ORT FaceNet session");

        Ok(EmbeddingGenerator { session })
    }
}

impl EmbeddingGenerator {
    const INPUT_NAME: &'static str = "serving_default_input_6:0";
    const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0";

    pub fn builder<T: AsRef<[u8]>>(
        model: T,
    ) -> std::result::Result<EmbeddingGeneratorBuilder, error_stack::Report<crate::errors::Error>>
    {
        EmbeddingGeneratorBuilder::new(model)
    }

    pub fn new(path: impl AsRef<Path>) -> crate::errors::Result<Self> {
        let model = std::fs::read(path)
            .change_context(Error)
            .attach_printable("Failed to read model file")?;
        Self::new_from_bytes(&model)
    }

    pub fn new_from_bytes(model: impl AsRef<[u8]>) -> crate::errors::Result<Self> {
        tracing::info!("Loading ORT face embedding model from bytes");
        Self::builder(model)?.build()
    }

    pub fn run_models(&mut self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
        // Convert input from u8 to f32 and normalize to [0, 1] range
        let input_tensor = crate::faceembed::preprocessing::preprocess(faces);

        // face_array = np.asarray(face_resized, 'float32')
        // mean, std = face_array.mean(), face_array.std()
        // face_normalized = (face_array - mean) / std
        // let input_tensor = faces.mean()

        tracing::trace!("Input tensor shape: {:?}", input_tensor.shape());

        // Create ORT input tensor
        let input_value = Tensor::from_array(input_tensor)
            .change_context(Error)
            .attach_printable("Failed to create input tensor")?;

        // Run inference
        tracing::debug!("Running ORT FaceNet inference");
        let outputs = self
            .session
            .run(ort::inputs![Self::INPUT_NAME => input_value])
            .change_context(Error)
            .attach_printable("Failed to run inference")?;

        // Extract output tensor
        let output = outputs
            .get(Self::OUTPUT_NAME)
            .ok_or(Error)
            .attach_printable("Missing output from FaceNet model")?
            .try_extract_tensor::<f32>()
            .change_context(Error)
            .attach_printable("Failed to extract output tensor")?;

        let (output_shape, output_data) = output;

        tracing::trace!("Output shape: {:?}", output_shape);

        // Convert to ndarray format
        let output_dims = output_shape.as_ref();

        // FaceNet typically outputs embeddings as [batch_size, embedding_dim]
        let batch_size = output_dims[0] as usize;
        let embedding_dim = output_dims[1] as usize;

        let output_array =
            ndarray::Array2::from_shape_vec((batch_size, embedding_dim), output_data.to_vec())
                .change_context(Error)
                .attach_printable("Failed to create output ndarray")?;

        tracing::trace!(
            "Generated embeddings with shape: {:?}",
            output_array.shape()
        );

        Ok(output_array)
    }
}

impl FaceNetEmbedder for EmbeddingGenerator {
    fn run_model(&mut self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
        self.run_models(faces)
    }
}

// Main trait implementation for backward compatibility
impl crate::faceembed::FaceEmbedder for EmbeddingGenerator {
    fn run_models(&mut self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
        // Need to create a mutable reference for the session
        // This is a workaround for the trait signature mismatch
        self.run_models(faces)
    }
}
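A minimal sketch (not part of the diff) of configuring the ORT-backed generator above. The provider list and thread count are illustrative choices, not defaults of this file.

// Sketch only: CPU execution provider with four intra-op threads.
fn build_ort_embedder(model_bytes: &[u8]) -> crate::errors::Result<EmbeddingGenerator> {
    EmbeddingGenerator::builder(model_bytes)?
        .with_execution_providers([ExecutionProvider::CPU])
        .with_intra_threads(4)
        .build()
}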
1053
src/gui/app.rs
Normal file
1053
src/gui/app.rs
Normal file
File diff suppressed because it is too large
Load Diff
569
src/gui/bridge.rs
Normal file
569
src/gui/bridge.rs
Normal file
@@ -0,0 +1,569 @@
|
||||
use std::path::PathBuf;

use crate::errors;
use crate::facedet::{FaceDetectionConfig, FaceDetector, retinaface};
use crate::faceembed::facenet;
use crate::gui::app::{ComparisonResult, DetectionResult, ExecutorType};
use bounding_box::Aabb2;
use bounding_box::roi::MultiRoi as _;
use error_stack::ResultExt;
use fast_image_resize::ResizeOptions;
use ndarray::{Array1, Array2, Array3, Array4};
use ndarray_image::ImageToNdarray;
use ndarray_math::CosineSimilarity;
use ndarray_resize::NdFir;

const RETINAFACE_MODEL_MNN: &[u8] = include_bytes!("../../models/retinaface.mnn");
const FACENET_MODEL_MNN: &[u8] = include_bytes!("../../models/facenet.mnn");
const RETINAFACE_MODEL_ONNX: &[u8] = include_bytes!("../../models/retinaface.onnx");
const FACENET_MODEL_ONNX: &[u8] = include_bytes!("../../models/facenet.onnx");

pub struct FaceDetectionBridge;

impl FaceDetectionBridge {
    pub async fn detect_faces(
        image_path: PathBuf,
        output_path: Option<PathBuf>,
        threshold: f32,
        nms_threshold: f32,
        executor_type: ExecutorType,
    ) -> DetectionResult {
        let start_time = std::time::Instant::now();

        match Self::run_detection_internal(
            image_path.clone(),
            output_path,
            threshold,
            nms_threshold,
            executor_type,
        )
        .await
        {
            Ok((faces_count, processed_image)) => {
                let processing_time = start_time.elapsed().as_secs_f64();
                DetectionResult::Success {
                    image_path,
                    faces_count,
                    processed_image,
                    processing_time,
                }
            }
            Err(error) => DetectionResult::Error(error.to_string()),
        }
    }

    pub async fn compare_faces(
        image1_path: PathBuf,
        image2_path: PathBuf,
        threshold: f32,
        nms_threshold: f32,
        executor_type: ExecutorType,
    ) -> ComparisonResult {
        let start_time = std::time::Instant::now();

        match Self::run_comparison_internal(
            image1_path,
            image2_path,
            threshold,
            nms_threshold,
            executor_type,
        )
        .await
        {
            Ok((
                image1_faces,
                image2_faces,
                image1_face_rois,
                image2_face_rois,
                best_similarity,
            )) => {
                let processing_time = start_time.elapsed().as_secs_f64();
                ComparisonResult::Success {
                    image1_faces,
                    image2_faces,
                    image1_face_rois,
                    image2_face_rois,
                    best_similarity,
                    processing_time,
                }
            }
            Err(error) => ComparisonResult::Error(error.to_string()),
        }
    }

    async fn run_detection_internal(
        image_path: PathBuf,
        output_path: Option<PathBuf>,
        threshold: f32,
        nms_threshold: f32,
        executor_type: ExecutorType,
    ) -> Result<(usize, Option<Vec<u8>>), Box<dyn std::error::Error + Send + Sync>> {
        // Load the image
        let img = image::open(&image_path)?;
        let img_rgb = img.to_rgb8();

        // Convert to ndarray format
        let image_array = img_rgb.as_ndarray()?;

        // Create detection configuration
        let config = FaceDetectionConfig::default()
            .with_threshold(threshold)
            .with_nms_threshold(nms_threshold)
            .with_input_width(1024)
            .with_input_height(1024);

        // Create detector and detect faces
        let faces = match executor_type {
            ExecutorType::MnnCpu => {
                let mut detector = retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
                    .map_err(|e| format!("Failed to create MNN detector: {}", e))?
                    .with_forward_type(mnn::ForwardType::CPU)
                    .build()
                    .map_err(|e| format!("Failed to build MNN detector: {}", e))?;

                detector
                    .detect_faces(image_array.view(), &config)
                    .map_err(|e| format!("Detection failed: {}", e))?
            }
            #[cfg(feature = "mnn-metal")]
            ExecutorType::MnnMetal => {
                let mut detector = retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
                    .map_err(|e| format!("Failed to create MNN detector: {}", e))?
                    .with_forward_type(mnn::ForwardType::Metal)
                    .build()
                    .map_err(|e| format!("Failed to build MNN detector: {}", e))?;

                detector
                    .detect_faces(image_array.view(), &config)
                    .map_err(|e| format!("Detection failed: {}", e))?
            }
            #[cfg(feature = "mnn-coreml")]
            ExecutorType::MnnCoreML => {
                let mut detector = retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
                    .map_err(|e| format!("Failed to create MNN detector: {}", e))?
                    .with_forward_type(mnn::ForwardType::CoreML)
                    .build()
                    .map_err(|e| format!("Failed to build MNN detector: {}", e))?;

                detector
                    .detect_faces(image_array.view(), &config)
                    .map_err(|e| format!("Detection failed: {}", e))?
            }
            ExecutorType::OnnxCpu => {
                let mut detector = retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
                    .map_err(|e| format!("Failed to create ONNX detector: {}", e))?
                    .build()
                    .map_err(|e| format!("Failed to build ONNX detector: {}", e))?;

                detector
                    .detect_faces(image_array.view(), &config)
                    .map_err(|e| format!("Detection failed: {}", e))?
            }
            #[cfg(feature = "ort-cuda")]
            ExecutorType::OrtCuda => {
                use crate::ort_ep::ExecutionProvider;

                let ep = ExecutionProvider::CUDA;
                let mut detector = retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
                    .map_err(|e| format!("Failed to create ONNX CUDA detector: {}", e))?
                    .with_execution_providers([ep])
                    .build()
                    .map_err(|e| format!("Failed to build ONNX CUDA detector: {}", e))?;

                detector
                    .detect_faces(image_array.view(), &config)
                    .map_err(|e| format!("CUDA detection failed: {}", e))?
            }
        };

        let faces_count = faces.bbox.len();

        // Generate output image with bounding boxes if requested
        let processed_image = if output_path.is_some() || true {
            // Always generate for GUI display
            let mut output_img = img.to_rgb8();

            for bbox in &faces.bbox {
                let min_point = bbox.min_vertex();
                let size = bbox.size();
                let rect = imageproc::rect::Rect::at(min_point.x as i32, min_point.y as i32)
                    .of_size(size.x as u32, size.y as u32);
                imageproc::drawing::draw_hollow_rect_mut(
                    &mut output_img,
                    rect,
                    image::Rgb([255, 0, 0]),
                );
            }

            // Convert to bytes for GUI display
            let mut buffer = Vec::new();
            let mut cursor = std::io::Cursor::new(&mut buffer);
            image::DynamicImage::ImageRgb8(output_img.clone())
                .write_to(&mut cursor, image::ImageFormat::Png)?;

            // Save to file if output path is specified
            if let Some(ref output_path) = output_path {
                output_img.save(output_path)?;
            }

            Some(buffer)
        } else {
            None
        };

        Ok((faces_count, processed_image))
    }

    async fn run_comparison_internal(
        image1_path: PathBuf,
        image2_path: PathBuf,
        threshold: f32,
        nms_threshold: f32,
        executor_type: ExecutorType,
    ) -> Result<
        (usize, usize, Vec<Array3<u8>>, Vec<Array3<u8>>, f32),
        Box<dyn std::error::Error + Send + Sync>,
    > {
        // Create detector and embedder, detect faces and generate embeddings
        let (image1_faces, image2_faces, image1_rois, image2_rois, best_similarity) =
            match executor_type {
                ExecutorType::MnnCpu => {
                    let mut detector =
                        retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
                            .map_err(|e| format!("Failed to create MNN detector: {}", e))?
                            .with_forward_type(mnn::ForwardType::CPU)
                            .build()
                            .map_err(|e| format!("Failed to build MNN detector: {}", e))?;

                    let mut embedder = facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
                        .map_err(|e| format!("Failed to create MNN embedder: {}", e))?
                        .with_forward_type(mnn::ForwardType::CPU)
                        .build()
                        .map_err(|e| format!("Failed to build MNN embedder: {}", e))?;

                    let img_1 = run_detection(
                        image1_path,
                        &mut detector,
                        &mut embedder,
                        threshold,
                        nms_threshold,
                        2,
                    )?;
                    let img_2 = run_detection(
                        image2_path,
                        &mut detector,
                        &mut embedder,
                        threshold,
                        nms_threshold,
                        2,
                    )?;

                    let image1_rois = img_1.rois;
                    let image2_rois = img_2.rois;
                    let image1_bbox_len = img_1.bbox.len();
                    let image2_bbox_len = img_2.bbox.len();
                    let best_similarity = compare_faces(&img_1.embeddings, &img_2.embeddings)?;

                    (
                        image1_bbox_len,
                        image2_bbox_len,
                        image1_rois,
                        image2_rois,
                        best_similarity,
                    )
                }
                #[cfg(feature = "mnn-metal")]
                ExecutorType::MnnMetal => {
                    let mut detector =
                        retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
                            .map_err(|e| format!("Failed to create MNN detector: {}", e))?
                            .with_forward_type(mnn::ForwardType::Metal)
                            .build()
                            .map_err(|e| format!("Failed to build MNN detector: {}", e))?;

                    let mut embedder = facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
                        .map_err(|e| format!("Failed to create MNN embedder: {}", e))?
                        .with_forward_type(mnn::ForwardType::Metal)
                        .build()
                        .map_err(|e| format!("Failed to build MNN embedder: {}", e))?;

                    let img_1 = run_detection(
                        image1_path,
                        &mut detector,
                        &mut embedder,
                        threshold,
                        nms_threshold,
                        2,
                    )?;
                    let img_2 = run_detection(
                        image2_path,
                        &mut detector,
                        &mut embedder,
                        threshold,
                        nms_threshold,
                        2,
                    )?;

                    let image1_rois = img_1.rois;
                    let image2_rois = img_2.rois;
                    let image1_bbox_len = img_1.bbox.len();
                    let image2_bbox_len = img_2.bbox.len();
                    let best_similarity = compare_faces(&img_1.embeddings, &img_2.embeddings)?;

                    (
                        image1_bbox_len,
                        image2_bbox_len,
                        image1_rois,
                        image2_rois,
                        best_similarity,
                    )
                }
                #[cfg(feature = "mnn-coreml")]
                ExecutorType::MnnCoreML => {
                    let mut detector =
                        retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
                            .map_err(|e| format!("Failed to create MNN detector: {}", e))?
                            .with_forward_type(mnn::ForwardType::CoreML)
                            .build()
                            .map_err(|e| format!("Failed to build MNN detector: {}", e))?;

                    let mut embedder = facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
                        .map_err(|e| format!("Failed to create MNN embedder: {}", e))?
                        .with_forward_type(mnn::ForwardType::CoreML)
                        .build()
                        .map_err(|e| format!("Failed to build MNN embedder: {}", e))?;

                    let img_1 = run_detection(
                        image1_path,
                        &mut detector,
                        &mut embedder,
                        threshold,
                        nms_threshold,
                        2,
                    )?;
                    let img_2 = run_detection(
                        image2_path,
                        &mut detector,
                        &mut embedder,
                        threshold,
                        nms_threshold,
                        2,
                    )?;

                    let image1_rois = img_1.rois;
                    let image2_rois = img_2.rois;
                    let image1_bbox_len = img_1.bbox.len();
                    let image2_bbox_len = img_2.bbox.len();
                    let best_similarity = compare_faces(&img_1.embeddings, &img_2.embeddings)?;

                    (
                        image1_bbox_len,
                        image2_bbox_len,
                        image1_rois,
                        image2_rois,
                        best_similarity,
                    )
                }
                ExecutorType::OnnxCpu => unimplemented!("ONNX face comparison not yet implemented"),
                #[cfg(feature = "ort-cuda")]
                ExecutorType::OrtCuda => {
                    use crate::ort_ep::ExecutionProvider;
                    let ep = ExecutionProvider::CUDA;
                    let mut detector =
                        retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
                            .map_err(|e| format!("Failed to create ONNX detector: {}", e))?
                            .with_execution_providers([ep])
                            .build()
                            .map_err(|e| format!("Failed to build ONNX detector: {}", e))?;

                    let mut embedder =
                        facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
                            .map_err(|e| format!("Failed to create ONNX embedder: {}", e))?
                            .with_execution_providers([ep])
                            .build()
                            .map_err(|e| format!("Failed to build ONNX embedder: {}", e))?;

                    let img_1 = run_detection(
                        image1_path,
                        &mut detector,
                        &mut embedder,
                        threshold,
                        nms_threshold,
                        2,
                    )?;
                    let img_2 = run_detection(
                        image2_path,
                        &mut detector,
                        &mut embedder,
                        threshold,
                        nms_threshold,
                        2,
                    )?;

                    let image1_rois = img_1.rois;
                    let image2_rois = img_2.rois;
                    let image1_bbox_len = img_1.bbox.len();
                    let image2_bbox_len = img_2.bbox.len();
                    let best_similarity = compare_faces(&img_1.embeddings, &img_2.embeddings)?;

                    (
                        image1_bbox_len,
                        image2_bbox_len,
                        image1_rois,
                        image2_rois,
                        best_similarity,
                    )
                }
            };

        Ok((
            image1_faces,
            image2_faces,
            image1_rois,
            image2_rois,
            best_similarity,
        ))
    }
}

use crate::errors::Error;
pub fn compare_faces(
    faces_1: &[Array1<f32>],
    faces_2: &[Array1<f32>],
) -> Result<f32, error_stack::Report<crate::errors::Error>> {
    use error_stack::Report;

    if faces_1.is_empty() || faces_2.is_empty() {
        Err(Report::new(crate::errors::Error))
            .attach_printable("One or both images have no detected faces")?;
    }
    if faces_1.len() != faces_2.len() {
        Err(Report::new(crate::errors::Error))
            .attach_printable("Face count mismatch between images")?;
    }
    Ok(faces_1
        .iter()
        .zip(faces_2)
        .flat_map(|(face_1, face_2)| face_1.cosine_similarity(face_2))
        .inspect(|v| tracing::info!("Cosine similarity: {}", v))
        .map(|v| ordered_float::OrderedFloat(v))
        .max()
        .map(|v| v.0)
        .ok_or(Report::new(Error))?)
}

#[derive(Debug)]
pub struct DetectionOutput {
    bbox: Vec<Aabb2<usize>>,
    rois: Vec<ndarray::Array3<u8>>,
    embeddings: Vec<Array1<f32>>,
}

fn run_detection<D, E>(
    image: impl AsRef<std::path::Path>,
    retinaface: &mut D,
    facenet: &mut E,
    threshold: f32,
    nms_threshold: f32,
    chunk_size: usize,
) -> crate::errors::Result<DetectionOutput>
where
    D: crate::facedet::FaceDetector,
    E: crate::faceembed::FaceEmbedder,
{
    use errors::*;
    // Initialize database if requested
    let image = image.as_ref();
    let image = image::open(image)
        .change_context(Error)
        .attach_printable(image.to_string_lossy().to_string())?;
    let image = image.into_rgb8();
    let mut array = image
        .into_ndarray()
        .change_context(errors::Error)
        .attach_printable("Failed to convert image to ndarray")?;
    let output = retinaface
        .detect_faces(
            array.view(),
            &FaceDetectionConfig::default()
                .with_threshold(threshold)
                .with_nms_threshold(nms_threshold),
        )
        .change_context(errors::Error)
        .attach_printable("Failed to detect faces")?;

    let bboxes = output
        .bbox
        .iter()
        .inspect(|bbox| tracing::info!("Raw bbox: {:?}", bbox))
        .map(|bbox| bbox.as_::<f32>().scale_uniform(1.30).as_::<usize>())
        .inspect(|bbox| tracing::info!("Padded bbox: {:?}", bbox))
        .collect_vec();
    for bbox in &bboxes {
        tracing::info!("Detected face: {:?}", bbox);
        use bounding_box::draw::*;
        array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1);
    }
    use itertools::Itertools;
    let face_rois = array
        .view()
        .multi_roi(&bboxes)
        .change_context(Error)?
        .into_iter()
        .map(|roi| {
            roi.as_standard_layout()
                .fast_resize(224, 224, &ResizeOptions::default())
                .change_context(Error)
        })
        .collect::<Result<Vec<_>>>()?;
    let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();

    let embeddings: Vec<Array1<f32>> = face_roi_views
        .chunks(chunk_size)
        .map(|chunk| {
            tracing::info!("Processing chunk of size: {}", chunk.len());

            let og_size = chunk.len();
            if chunk.len() < chunk_size {
                tracing::warn!("Chunk is smaller than chunk_size, padding with zeros");
                let zeros = Array3::zeros((224, 224, 3));
                let chunk: Vec<_> = chunk
                    .iter()
                    .map(|arr| arr.reborrow())
                    .chain(core::iter::repeat(zeros.view()))
                    .take(chunk_size)
                    .collect();
                let face_rois: Array4<u8> = ndarray::stack(ndarray::Axis(0), chunk.as_slice())
                    .change_context(errors::Error)
                    .attach_printable("Failed to stack rois together")?;
                let output = facenet.run_models(face_rois.view()).change_context(Error)?;
                Ok((output, og_size))
            } else {
                let face_rois: Array4<u8> = ndarray::stack(ndarray::Axis(0), chunk)
                    .change_context(errors::Error)
                    .attach_printable("Failed to stack rois together")?;
                let output = facenet.run_models(face_rois.view()).change_context(Error)?;
                Ok((output, og_size))
            }
        })
        .collect::<Result<Vec<(Array2<f32>, usize)>>>()?
        .into_iter()
        .map(|(chunk, size): (Array2<f32>, usize)| {
            use itertools::Itertools;
            chunk
                .rows()
                .into_iter()
                .take(size)
                .map(|row| row.to_owned())
                .collect_vec()
                .into_iter()
        })
        .flatten()
        .collect::<Vec<Array1<f32>>>();

    Ok(DetectionOutput {
        bbox: bboxes,
        rois: face_rois,
        embeddings,
    })
}
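For reference, the score that compare_faces above maximises is plain cosine similarity between two embeddings. The sketch below (not part of the diff) re-states that formula inline; the real code delegates to ndarray_math::CosineSimilarity instead of computing it by hand.

// Sketch only: cosine similarity of two embedding vectors, guarding against
// zero-norm inputs.
fn cosine_similarity(a: &ndarray::Array1<f32>, b: &ndarray::Array1<f32>) -> f32 {
    let dot = a.dot(b);
    let norms = a.dot(a).sqrt() * b.dot(b).sqrt();
    if norms == 0.0 { 0.0 } else { dot / norms }
}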
5
src/gui/mod.rs
Normal file
5
src/gui/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
pub mod app;
pub mod bridge;

pub use app::{FaceDetectorApp, Message, run};
pub use bridge::FaceDetectionBridge;
@@ -1,5 +0,0 @@
|
||||
// pub struct Image {
//     pub width: u32,
//     pub height: u32,
//     pub data: Vec<u8>,
// }
@@ -1,4 +1,7 @@
|
||||
pub mod database;
pub mod errors;
pub mod facedet;
pub mod image;
use errors::*;
pub mod faceembed;
pub mod gui;
pub mod ort_ep;
pub use errors::*;
59
src/main.rs
59
src/main.rs
@@ -1,59 +0,0 @@
|
||||
mod cli;
mod errors;
use detector::facedet::retinaface::FaceDetectionConfig;
use errors::*;
use ndarray_image::*;
const RETINAFACE_MODEL: &[u8] = include_bytes!("../models/retinaface.mnn");
pub fn main() -> Result<()> {
    tracing_subscriber::fmt()
        .with_env_filter("trace")
        .with_thread_ids(true)
        .with_thread_names(true)
        .with_target(false)
        .init();
    let args = <cli::Cli as clap::Parser>::parse();
    match args.cmd {
        cli::SubCommand::Detect(detect) => {
            use detector::facedet;
            let model = facedet::retinaface::FaceDetection::new_from_bytes(RETINAFACE_MODEL)
                .change_context(errors::Error)
                .attach_printable("Failed to create face detection model")?;
            let image = image::open(detect.image).change_context(Error)?;
            let image = image.into_rgb8();
            let mut array = image
                .into_ndarray()
                .change_context(errors::Error)
                .attach_printable("Failed to convert image to ndarray")?;
            let output = model
                .detect_faces(
                    array.clone(),
                    FaceDetectionConfig::default().with_threshold(detect.threshold),
                )
                .change_context(errors::Error)
                .attach_printable("Failed to detect faces")?;
            for bbox in output.bbox {
                tracing::info!("Detected face: {:?}", bbox);
                use bounding_box::draw::*;
                array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1);
            }
            let v = array.view();
            if let Some(output) = detect.output {
                let image: image::RgbImage = v
                    .to_image()
                    .change_context(errors::Error)
                    .attach_printable("Failed to convert ndarray to image")?;
                image
                    .save(output)
                    .change_context(errors::Error)
                    .attach_printable("Failed to save output image")?;
            }
        }
        cli::SubCommand::List(list) => {
            println!("List: {:?}", list);
        }
        cli::SubCommand::Completions { shell } => {
            cli::Cli::completions(shell);
        }
    }
    Ok(())
}
197
src/ort_ep.rs
Normal file
197
src/ort_ep.rs
Normal file
@@ -0,0 +1,197 @@
|
||||
#[cfg(feature = "ort-cuda")]
use ort::execution_providers::CUDAExecutionProvider;
#[cfg(feature = "ort-coreml")]
use ort::execution_providers::CoreMLExecutionProvider;
#[cfg(feature = "ort-directml")]
use ort::execution_providers::DirectMLExecutionProvider;
#[cfg(feature = "ort-openvino")]
use ort::execution_providers::OpenVINOExecutionProvider;
#[cfg(feature = "ort-tvm")]
use ort::execution_providers::TVMExecutionProvider;
#[cfg(feature = "ort-tensorrt")]
use ort::execution_providers::TensorRTExecutionProvider;
use ort::execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch};

/// Supported execution providers for ONNX Runtime
#[derive(Debug, Copy, Clone)]
pub enum ExecutionProvider {
    /// CPU execution provider (always available)
    CPU,
    /// CoreML execution provider (macOS only)
    CoreML,
    /// CUDA execution provider (requires cuda feature)
    CUDA,
    /// TensorRT execution provider (requires tensorrt feature)
    TensorRT,
    /// TVM execution provider (requires tvm feature)
    TVM,
    /// OpenVINO execution provider (requires openvino feature)
    OpenVINO,
    /// DirectML execution provider (Windows only, requires directml feature)
    DirectML,
}

impl std::fmt::Display for ExecutionProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ExecutionProvider::CPU => write!(f, "CPU"),
            ExecutionProvider::CoreML => write!(f, "CoreML"),
            ExecutionProvider::CUDA => write!(f, "CUDA"),
            ExecutionProvider::TensorRT => write!(f, "TensorRT"),
            ExecutionProvider::TVM => write!(f, "TVM"),
            ExecutionProvider::OpenVINO => write!(f, "OpenVINO"),
            ExecutionProvider::DirectML => write!(f, "DirectML"),
        }
    }
}

impl std::str::FromStr for ExecutionProvider {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "cpu" => Ok(ExecutionProvider::CPU),
            "coreml" => Ok(ExecutionProvider::CoreML),
            "cuda" => Ok(ExecutionProvider::CUDA),
            "tensorrt" => Ok(ExecutionProvider::TensorRT),
            "tvm" => Ok(ExecutionProvider::TVM),
            "openvino" => Ok(ExecutionProvider::OpenVINO),
            "directml" => Ok(ExecutionProvider::DirectML),
            _ => Err(format!("Unknown execution provider: {}", s)),
        }
    }
}

impl ExecutionProvider {
    /// Returns all available execution providers for the current platform and features
    pub fn available_providers() -> Vec<ExecutionProvider> {
        vec![
            ExecutionProvider::CPU,
            #[cfg(all(target_os = "macos", feature = "ort-coreml"))]
            ExecutionProvider::CoreML,
            #[cfg(feature = "ort-cuda")]
            ExecutionProvider::CUDA,
            #[cfg(feature = "ort-tensorrt")]
            ExecutionProvider::TensorRT,
            #[cfg(feature = "ort-tvm")]
            ExecutionProvider::TVM,
            #[cfg(feature = "ort-openvino")]
            ExecutionProvider::OpenVINO,
            #[cfg(all(target_os = "windows", feature = "ort-directml"))]
            ExecutionProvider::DirectML,
        ]
    }

    /// Check if this execution provider is available on the current platform
    pub fn is_available(&self) -> bool {
        match self {
            ExecutionProvider::CPU => true,
            ExecutionProvider::CoreML => cfg!(target_os = "macos") && cfg!(feature = "ort-coreml"),
            ExecutionProvider::CUDA => cfg!(feature = "ort-cuda"),
            ExecutionProvider::TensorRT => cfg!(feature = "ort-tensorrt"),
            ExecutionProvider::TVM => cfg!(feature = "ort-tvm"),
            ExecutionProvider::OpenVINO => cfg!(feature = "ort-openvino"),
            ExecutionProvider::DirectML => {
                cfg!(target_os = "windows") && cfg!(feature = "ort-directml")
            }
        }
    }
}

impl ExecutionProvider {
    pub fn to_dispatch(&self) -> Option<ExecutionProviderDispatch> {
        match self {
            ExecutionProvider::CPU => Some(CPUExecutionProvider::default().build()),
            ExecutionProvider::CoreML => {
                #[cfg(target_os = "macos")]
                {
                    #[cfg(feature = "ort-coreml")]
                    {
                        use tap::Tap;

                        Some(
                            CoreMLExecutionProvider::default()
                                .with_model_format(
                                    ort::execution_providers::coreml::CoreMLModelFormat::MLProgram,
                                )
                                .build(),
                        )
                    }
                    #[cfg(not(feature = "ort-coreml"))]
                    {
                        tracing::error!("coreml support not compiled in");
                        None
                    }
                }
                #[cfg(not(target_os = "macos"))]
                {
                    tracing::error!("CoreML is only available on macOS");
                    None
                }
            }
            ExecutionProvider::CUDA => {
                #[cfg(feature = "ort-cuda")]
                {
                    Some(CUDAExecutionProvider::default().build())
                }
                #[cfg(not(feature = "ort-cuda"))]
                {
                    tracing::error!("CUDA support not compiled in");
                    None
                }
            }
            ExecutionProvider::TensorRT => {
                #[cfg(feature = "ort-tensorrt")]
                {
                    Some(TensorRTExecutionProvider::default().build())
                }
                #[cfg(not(feature = "ort-tensorrt"))]
                {
                    tracing::error!("TensorRT support not compiled in");
                    None
                }
            }
            ExecutionProvider::TVM => {
                #[cfg(feature = "ort-tvm")]
                {
                    Some(TVMExecutionProvider::default().build())
                }
                #[cfg(not(feature = "ort-tvm"))]
                {
                    tracing::error!("TVM support not compiled in");
                    None
                }
            }
            ExecutionProvider::OpenVINO => {
                #[cfg(feature = "ort-openvino")]
                {
                    Some(OpenVINOExecutionProvider::default().build())
                }
                #[cfg(not(feature = "ort-openvino"))]
                {
                    tracing::error!("OpenVINO support not compiled in");
                    None
                }
            }
            ExecutionProvider::DirectML => {
                #[cfg(target_os = "windows")]
                {
                    #[cfg(feature = "ort-directml")]
                    {
                        Some(DirectMLExecutionProvider::default().build())
                    }
                    #[cfg(not(feature = "ort-directml"))]
                    {
                        tracing::error!("DirectML support not compiled in");
                        None
                    }
                }
                #[cfg(not(target_os = "windows"))]
                {
                    tracing::error!("DirectML is only available on Windows");
                    None
                }
            }
        }
    }
}
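A minimal sketch (not part of the diff) of how the helpers above could resolve a user-supplied provider name, checking availability and falling back to the CPU provider when the requested backend was not compiled in.

// Sketch only: "cuda" is an example input; any name accepted by FromStr works.
fn resolve_provider(name: &str) -> ExecutionProviderDispatch {
    use std::str::FromStr;
    ExecutionProvider::from_str(name)
        .ok()
        .filter(|ep| ep.is_available())
        .and_then(|ep| ep.to_dispatch())
        .unwrap_or_else(|| CPUExecutionProvider::default().build())
}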