Compare commits

...

22 Commits

Author SHA1 Message Date
uttarayan21
59a3fddc0b chore: delete unused files and outdated GUI_DEMO documentation
Some checks failed
build / checks-matrix (push) Successful in 19m22s
build / codecov (push) Failing after 19m22s
docs / docs (push) Failing after 28m48s
build / checks-build (push) Has been cancelled
2025-09-23 16:13:56 +05:30
uttarayan21
eb9451aad8 chore: remove submodule 'rfcs' from the project
Some checks failed
build / checks-matrix (push) Successful in 19m24s
build / codecov (push) Failing after 19m23s
docs / docs (push) Has been cancelled
build / checks-build (push) Has been cancelled
2025-09-23 15:08:54 +05:30
uttarayan21
c6b3f5279f feat(flake): add uv package to build inputs
Some checks failed
build / checks-matrix (push) Successful in 19m21s
build / codecov (push) Failing after 19m25s
docs / docs (push) Failing after 28m47s
build / checks-build (push) Has been cancelled
2025-09-16 12:28:09 +05:30
uttarayan21
a419a5ac4a chore(models): remove Facenet and RetinaFace model files
Some checks failed
build / checks-matrix (push) Has been cancelled
build / checks-build (push) Has been cancelled
build / codecov (push) Has been cancelled
docs / docs (push) Has been cancelled
2025-09-16 12:22:38 +05:30
uttarayan21
a340552257 feat(cli): add clustering command with K-means support
Some checks failed
build / checks-matrix (push) Successful in 19m25s
build / codecov (push) Failing after 19m26s
docs / docs (push) Failing after 28m52s
build / checks-build (push) Has been cancelled
2025-09-13 17:45:55 +05:30
uttarayan21
aaf34ef74e refactor: rename sqlite3-safetensor-cosine to sqlite3-ndarray-math
Some checks failed
build / checks-matrix (push) Successful in 19m20s
build / codecov (push) Failing after 19m22s
docs / docs (push) Failing after 28m47s
build / checks-build (push) Has been cancelled
2025-08-28 18:42:35 +05:30
uttarayan21
ac8f1d01b4 feat(detector): add CUDA support for ONNX face detection
Some checks failed
build / checks-build (push) Has been cancelled
build / codecov (push) Has been cancelled
docs / docs (push) Has been cancelled
build / checks-matrix (push) Has been cancelled
2025-08-28 18:32:00 +05:30
uttarayan21
4256c0af74 feat(makefile): add conversion task and update model binaries
Some checks failed
build / checks-matrix (push) Successful in 19m22s
build / codecov (push) Failing after 19m22s
docs / docs (push) Failing after 28m50s
build / checks-build (push) Has been cancelled
2025-08-28 13:43:23 +05:30
uttarayan21
3eec262076 feat(bounding-box): add scale_uniform method for consistent scaling
Some checks failed
build / checks-matrix (push) Successful in 19m22s
build / codecov (push) Failing after 19m26s
docs / docs (push) Failing after 28m51s
build / checks-build (push) Has been cancelled
feat(gui): display face ROIs in comparison results

refactor(bridge): pad detected face bounding boxes uniformly
2025-08-22 19:01:34 +05:30
uttarayan21
c758fd8d41 feat(gui): add face ROIs to comparison results and update image size 2025-08-22 18:26:29 +05:30
uttarayan21
34eaf9348a refactor(gui): remove commented-out code in face detection function
Some checks failed
build / checks-matrix (push) Successful in 19m20s
build / codecov (push) Failing after 19m18s
docs / docs (push) Has been cancelled
build / checks-build (push) Has been cancelled
2025-08-22 18:15:55 +05:30
uttarayan21
dab7719206 refactor: replace bbox::BBox with bounding_box::Aabb2 across codebase
Some checks failed
build / checks-matrix (push) Has been cancelled
build / checks-build (push) Has been cancelled
build / codecov (push) Has been cancelled
docs / docs (push) Has been cancelled
2025-08-22 18:14:58 +05:30
uttarayan21
4b4d23d1d4 feat(bbox): add bounding box implementation with serialization
Add initial implementation of the `BBox` struct in the `bbox` module,
including basic operations and serialization/deserialization support
with Serde.
2025-08-22 15:27:47 +05:30
uttarayan21
aab3d84db0 feat(ndcv-bridge): add ndcv-bridge for ndarray and opencv interaction 2025-08-22 15:27:36 +05:30
uttarayan21
65560825fa feat: add cargo-outdated and improve slider precision in app views
Some checks failed
build / checks-matrix (push) Successful in 19m24s
build / codecov (push) Failing after 19m27s
docs / docs (push) Failing after 28m47s
build / checks-build (push) Has been cancelled
2025-08-22 13:06:16 +05:30
uttarayan21
0a5dbaaadc refactor(gui): set fixed input dimensions for face detection 2025-08-21 18:52:58 +05:30
uttarayan21
3e14a16739 feat(gui): Added iced gui 2025-08-21 18:28:39 +05:30
uttarayan21
bfa389b497 feat(compare): add face comparison functionality with cosine similarity
Some checks failed
build / checks-matrix (push) Successful in 19m23s
build / codecov (push) Failing after 19m18s
docs / docs (push) Failing after 28m50s
build / checks-build (push) Has been cancelled
2025-08-21 17:34:07 +05:30
uttarayan21
f8122892e0 feat(ndarray-safetensors): add tensor_by_index method for SafeArraysView
Some checks failed
build / checks-matrix (push) Successful in 19m24s
build / codecov (push) Failing after 19m27s
docs / docs (push) Failing after 28m51s
build / checks-build (push) Has been cancelled
2025-08-20 16:05:18 +05:30
uttarayan21
97f64e7e10 feat: save safetensors to the database
Some checks failed
build / checks-matrix (push) Successful in 19m23s
build / codecov (push) Failing after 19m26s
docs / docs (push) Failing after 28m47s
build / checks-build (push) Has been cancelled
2025-08-20 12:17:18 +05:30
uttarayan21
37adb74adf feat: Save tensors to database as safetensor 2025-08-20 12:17:18 +05:30
uttarayan21
47218fa696 feat: Added ndarray-safetensors 2025-08-20 12:17:16 +05:30
64 changed files with 12887 additions and 380 deletions

4400
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,11 +1,27 @@
[workspace] [workspace]
members = ["ndarray-image", "ndarray-resize", ".", "bounding-box"] members = [
"ndarray-image",
"ndarray-resize",
".",
"bounding-box",
"ndarray-safetensors",
"sqlite3-ndarray-math",
"ndcv-bridge",
"bbox",
]
[workspace.package] [workspace.package]
version = "0.1.0" version = "0.1.0"
edition = "2024" edition = "2024"
[patch.crates-io]
linfa = { git = "https://github.com/relf/linfa", branch = "upgrade-ndarray-0.16" }
linfa-clustering = { git = "https://github.com/relf/linfa", branch = "upgrade-ndarray-0.16" }
[workspace.dependencies] [workspace.dependencies]
divan = { version = "0.1.21" }
ndarray-npy = "0.9.1"
serde = { version = "1.0", features = ["derive"] }
ndarray-image = { path = "ndarray-image" } ndarray-image = { path = "ndarray-image" }
ndarray-resize = { path = "ndarray-resize" } ndarray-resize = { path = "ndarray-resize" }
mnn = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.2.0", features = [ mnn = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.2.0", features = [
@@ -20,6 +36,15 @@ mnn-sync = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0", f
"tracing", "tracing",
], branch = "restructure-tensor-type" } ], branch = "restructure-tensor-type" }
nalgebra = { version = "0.34.0", default-features = false, features = ["std"] } nalgebra = { version = "0.34.0", default-features = false, features = ["std"] }
opencv = { version = "0.95.1" }
bounding-box = { path = "bounding-box" }
bytemuck = "1.23.2"
error-stack = "0.5.0"
thiserror = "2.0"
fast_image_resize = "5.2.0"
img-parts = "0.4.0"
ndarray = { version = "0.16.1", features = ["rayon"] }
num = "0.4"
[package] [package]
name = "detector" name = "detector"
@@ -37,7 +62,7 @@ nalgebra = { workspace = true }
ndarray = "0.16.1" ndarray = "0.16.1"
ndarray-image = { workspace = true } ndarray-image = { workspace = true }
ndarray-resize = { workspace = true } ndarray-resize = { workspace = true }
rusqlite = { version = "0.37.0", features = ["modern-full"] } rusqlite = { version = "0.37.0", features = ["functions", "modern-full"] }
tap = "1.0.1" tap = "1.0.1"
thiserror = "2.0" thiserror = "2.0"
tokio = "1.43.1" tokio = "1.43.1"
@@ -54,13 +79,25 @@ ort = { version = "2.0.0-rc.10", default-features = false, features = [
"std", "std",
"tracing", "tracing",
"ndarray", "ndarray",
"cuda",
] } ] }
ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" }
ndarray-safetensors = { version = "0.1.0", path = "ndarray-safetensors" }
sqlite3-ndarray-math = { version = "0.1.0", path = "sqlite3-ndarray-math" }
# GUI dependencies
iced = { version = "0.13", features = ["tokio", "image"] }
rfd = "0.15"
futures = "0.3"
imageproc = "0.25"
linfa = "0.7.1"
linfa-clustering = "0.7.1"
[profile.release] [profile.release]
debug = true debug = true
[features] [features]
ort-cuda = ["ort/cuda"] ort-cuda = []
ort-coreml = ["ort/coreml"] ort-coreml = ["ort/coreml"]
ort-tensorrt = ["ort/tensorrt"] ort-tensorrt = ["ort/tensorrt"]
ort-tvm = ["ort/tvm"] ort-tvm = ["ort/tvm"]
@@ -69,4 +106,8 @@ ort-directml = ["ort/directml"]
mnn-metal = ["mnn/metal"] mnn-metal = ["mnn/metal"]
mnn-coreml = ["mnn/coreml"] mnn-coreml = ["mnn/coreml"]
default = [] default = ["ort-cuda"]
[[test]]
name = "test_bbox_replacement"
path = "test_bbox_replacement.rs"

View File

@@ -1,3 +1,7 @@
[tasks.convert]
dependencies = ["convert_facenet", "convert_retinaface"]
workspace = false
[tasks.convert_facenet] [tasks.convert_facenet]
command = "MNNConvert" command = "MNNConvert"
args = [ args = [
@@ -11,6 +15,7 @@ args = [
"--bizCode", "--bizCode",
"MNN", "MNN",
] ]
workspace = false
[tasks.convert_retinaface] [tasks.convert_retinaface]
command = "MNNConvert" command = "MNNConvert"
@@ -25,3 +30,9 @@ args = [
"--bizCode", "--bizCode",
"MNN", "MNN",
] ]
workspace = false
[tasks.gui]
command = "cargo"
args = ["run", "--release", "--bin", "gui"]
workspace = false

View File

@@ -55,6 +55,35 @@ cargo run --release detect --output detected.jpg path/to/image.jpg
cargo run --release detect --threshold 0.9 --nms-threshold 0.4 path/to/image.jpg cargo run --release detect --threshold 0.9 --nms-threshold 0.4 path/to/image.jpg
``` ```
### Face Comparison
Compare faces between two images by computing and comparing their embeddings:
```bash
# Compare faces in two images
cargo run --release compare image1.jpg image2.jpg
# Compare with custom thresholds
cargo run --release compare --threshold 0.9 --nms-threshold 0.4 image1.jpg image2.jpg
# Use ONNX Runtime backend for comparison
cargo run --release compare -p cpu image1.jpg image2.jpg
# Use MNN with Metal acceleration
cargo run --release compare -f metal image1.jpg image2.jpg
```
The compare command will:
1. Detect all faces in both images
2. Generate embeddings for each detected face
3. Compute cosine similarity between all face pairs
4. Display similarity scores and the best match
5. Provide interpretation of the similarity scores:
- **> 0.8**: Very likely the same person
- **0.6-0.8**: Possibly the same person
- **0.4-0.6**: Unlikely to be the same person
- **< 0.4**: Very unlikely to be the same person
### Backend Selection ### Backend Selection
The project supports two inference backends: The project supports two inference backends:

1
assets/headshots Symbolic link
View File

@@ -0,0 +1 @@
/Users/fs0c131y/Pictures/test_cases/compressed/HeadshotJpeg

13
bbox/Cargo.toml Normal file
View File

@@ -0,0 +1,13 @@
[package]
name = "bbox"
version.workspace = true
edition.workspace = true
[dependencies]
ndarray.workspace = true
num = "0.4.3"
serde = { workspace = true, features = ["derive"], optional = true }
[features]
serde = ["dep:serde"]
default = ["serde"]

708
bbox/src/lib.rs Normal file
View File

@@ -0,0 +1,708 @@
pub mod traits;
/// A bounding box of co-ordinates whose origin is at the top-left corner.
#[derive(
Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Hash, serde::Serialize, serde::Deserialize,
)]
#[non_exhaustive]
pub struct BBox<T = f32> {
pub x: T,
pub y: T,
pub width: T,
pub height: T,
}
impl<T> From<[T; 4]> for BBox<T> {
fn from([x, y, width, height]: [T; 4]) -> Self {
Self {
x,
y,
width,
height,
}
}
}
impl<T: Copy> BBox<T> {
pub fn new(x: T, y: T, width: T, height: T) -> Self {
Self {
x,
y,
width,
height,
}
}
/// Casts the internal values to another type using [as] keyword
pub fn cast<T2>(self) -> BBox<T2>
where
T: num::cast::AsPrimitive<T2>,
T2: Copy + 'static,
{
BBox {
x: self.x.as_(),
y: self.y.as_(),
width: self.width.as_(),
height: self.height.as_(),
}
}
/// Clamps all the internal values to the given min and max.
pub fn clamp(&self, min: T, max: T) -> Self
where
T: std::cmp::PartialOrd,
{
Self {
x: num::clamp(self.x, min, max),
y: num::clamp(self.y, min, max),
width: num::clamp(self.width, min, max),
height: num::clamp(self.height, min, max),
}
}
pub fn clamp_box(&self, bbox: BBox<T>) -> Self
where
T: std::cmp::PartialOrd,
T: num::Zero,
T: core::ops::Add<Output = T>,
T: core::ops::Sub<Output = T>,
{
let x1 = num::clamp(self.x1(), bbox.x1(), bbox.x2());
let y1 = num::clamp(self.y1(), bbox.y1(), bbox.y2());
let x2 = num::clamp(self.x2(), bbox.x1(), bbox.x2());
let y2 = num::clamp(self.y2(), bbox.y1(), bbox.y2());
Self::new_xyxy(x1, y1, x2, y2)
}
pub fn normalize(&self, width: T, height: T) -> Self
where
T: core::ops::Div<Output = T> + Copy,
{
Self {
x: self.x / width,
y: self.y / height,
width: self.width / width,
height: self.height / height,
}
}
/// Normalize after casting to float
pub fn normalize_f64(&self, width: T, height: T) -> BBox<f64>
where
T: core::ops::Div<Output = T> + Copy,
T: num::cast::AsPrimitive<f64>,
{
BBox {
x: self.x.as_() / width.as_(),
y: self.y.as_() / height.as_(),
width: self.width.as_() / width.as_(),
height: self.height.as_() / height.as_(),
}
}
pub fn denormalize(&self, width: T, height: T) -> Self
where
T: core::ops::Mul<Output = T> + Copy,
{
Self {
x: self.x * width,
y: self.y * height,
width: self.width * width,
height: self.height * height,
}
}
pub fn height(&self) -> T {
self.height
}
pub fn width(&self) -> T {
self.width
}
pub fn padding(&self, padding: T) -> Self
where
T: core::ops::Add<Output = T> + core::ops::Sub<Output = T> + Copy,
{
Self {
x: self.x - padding,
y: self.y - padding,
width: self.width + padding + padding,
height: self.height + padding + padding,
}
}
pub fn padding_height(&self, padding: T) -> Self
where
T: core::ops::Add<Output = T> + core::ops::Sub<Output = T> + Copy,
{
Self {
x: self.x,
y: self.y - padding,
width: self.width,
height: self.height + padding + padding,
}
}
pub fn padding_width(&self, padding: T) -> Self
where
T: core::ops::Add<Output = T> + core::ops::Sub<Output = T> + Copy,
{
Self {
x: self.x - padding,
y: self.y,
width: self.width + padding + padding,
height: self.height,
}
}
// Enlarge / shrink the bounding box by a factor while
// keeping the center point and the aspect ratio fixed
pub fn scale(&self, factor: T) -> Self
where
T: core::ops::Mul<Output = T>,
T: core::ops::Sub<Output = T>,
T: core::ops::Add<Output = T>,
T: core::ops::Div<Output = T>,
T: num::One + Copy,
{
let two = num::one::<T>() + num::one::<T>();
let width = self.width * factor;
let height = self.height * factor;
let width_inc = width - self.width;
let height_inc = height - self.height;
Self {
x: self.x - width_inc / two,
y: self.y - height_inc / two,
width,
height,
}
}
pub fn scale_x(&self, factor: T) -> Self
where
T: core::ops::Mul<Output = T>
+ core::ops::Sub<Output = T>
+ core::ops::Add<Output = T>
+ core::ops::Div<Output = T>
+ num::One
+ Copy,
{
let two = num::one::<T>() + num::one::<T>();
let width = self.width * factor;
let width_inc = width - self.width;
Self {
x: self.x - width_inc / two,
y: self.y,
width,
height: self.height,
}
}
pub fn scale_y(&self, factor: T) -> Self
where
T: core::ops::Mul<Output = T>
+ core::ops::Sub<Output = T>
+ core::ops::Add<Output = T>
+ core::ops::Div<Output = T>
+ num::One
+ Copy,
{
let two = num::one::<T>() + num::one::<T>();
let height = self.height * factor;
let height_inc = height - self.height;
Self {
x: self.x,
y: self.y - height_inc / two,
width: self.width,
height,
}
}
pub fn offset(&self, offset: Point<T>) -> Self
where
T: core::ops::Add<Output = T> + Copy,
{
Self {
x: self.x + offset.x,
y: self.y + offset.y,
width: self.width,
height: self.height,
}
}
/// Translate the bounding box by the given offset
/// if they are in the same scale
pub fn translate(&self, bbox: Self) -> Self
where
T: core::ops::Add<Output = T> + Copy,
{
Self {
x: self.x + bbox.x,
y: self.y + bbox.y,
width: self.width,
height: self.height,
}
}
pub fn with_top_left(&self, top_left: Point<T>) -> Self {
Self {
x: top_left.x,
y: top_left.y,
width: self.width,
height: self.height,
}
}
pub fn center(&self) -> Point<T>
where
T: core::ops::Add<Output = T> + core::ops::Div<Output = T> + Copy,
T: num::One,
{
let two = T::one() + T::one();
Point::new(self.x + self.width / two, self.y + self.height / two)
}
pub fn area(&self) -> T
where
T: core::ops::Mul<Output = T> + Copy,
{
self.width * self.height
}
// Corresponds to self.x1() and self.y1()
pub fn top_left(&self) -> Point<T> {
Point::new(self.x, self.y)
}
pub fn top_right(&self) -> Point<T>
where
T: core::ops::Add<Output = T> + Copy,
{
Point::new(self.x + self.width, self.y)
}
pub fn bottom_left(&self) -> Point<T>
where
T: core::ops::Add<Output = T> + Copy,
{
Point::new(self.x, self.y + self.height)
}
// Corresponds to self.x2() and self.y2()
pub fn bottom_right(&self) -> Point<T>
where
T: core::ops::Add<Output = T> + Copy,
{
Point::new(self.x + self.width, self.y + self.height)
}
pub const fn x1(&self) -> T {
self.x
}
pub const fn y1(&self) -> T {
self.y
}
pub fn x2(&self) -> T
where
T: core::ops::Add<Output = T> + Copy,
{
self.x + self.width
}
pub fn y2(&self) -> T
where
T: core::ops::Add<Output = T> + Copy,
{
self.y + self.height
}
pub fn overlap(&self, other: &Self) -> T
where
T: std::cmp::PartialOrd
+ traits::min::Min
+ traits::max::Max
+ num::Zero
+ core::ops::Add<Output = T>
+ core::ops::Sub<Output = T>
+ core::ops::Mul<Output = T>
+ Copy,
{
let x1 = self.x.max(other.x);
let y1 = self.y.max(other.y);
let x2 = (self.x + self.width).min(other.x + other.width);
let y2 = (self.y + self.height).min(other.y + other.height);
let width = (x2 - x1).max(T::zero());
let height = (y2 - y1).max(T::zero());
width * height
}
pub fn iou(&self, other: &Self) -> T
where
T: std::cmp::Ord
+ num::Zero
+ traits::min::Min
+ traits::max::Max
+ core::ops::Add<Output = T>
+ core::ops::Sub<Output = T>
+ core::ops::Mul<Output = T>
+ core::ops::Div<Output = T>
+ Copy,
{
let overlap = self.overlap(other);
let union = self.area() + other.area() - overlap;
overlap / union
}
pub fn contains(&self, point: Point<T>) -> bool
where
T: std::cmp::PartialOrd + core::ops::Add<Output = T> + Copy,
{
point.x >= self.x
&& point.x <= self.x + self.width
&& point.y >= self.y
&& point.y <= self.y + self.height
}
pub fn contains_bbox(&self, other: Self) -> bool
where
T: std::cmp::PartialOrd + Copy,
T: core::ops::Add<Output = T>,
{
self.contains(other.top_left())
&& self.contains(other.top_right())
&& self.contains(other.bottom_left())
&& self.contains(other.bottom_right())
}
pub fn new_xywh(x: T, y: T, width: T, height: T) -> Self {
Self {
x,
y,
width,
height,
}
}
pub fn new_xyxy(x1: T, y1: T, x2: T, y2: T) -> Self
where
T: core::ops::Sub<Output = T> + Copy,
{
Self {
x: x1,
y: y1,
width: x2 - x1,
height: y2 - y1,
}
}
pub fn containing(box1: Self, box2: Self) -> Self
where
T: traits::min::Min + traits::max::Max + Copy,
T: core::ops::Sub<Output = T>,
T: core::ops::Add<Output = T>,
{
let x1 = box1.x.min(box2.x);
let y1 = box1.y.min(box2.y);
let x2 = box1.x2().max(box2.x2());
let y2 = box1.y2().max(box2.y2());
Self::new_xyxy(x1, y1, x2, y2)
}
}
impl<T: core::ops::Sub<Output = T> + Copy> core::ops::Sub<T> for BBox<T> {
type Output = BBox<T>;
fn sub(self, rhs: T) -> Self::Output {
BBox {
x: self.x - rhs,
y: self.y - rhs,
width: self.width - rhs,
height: self.height - rhs,
}
}
}
impl<T: core::ops::Add<Output = T> + Copy> core::ops::Add<T> for BBox<T> {
type Output = BBox<T>;
fn add(self, rhs: T) -> Self::Output {
BBox {
x: self.x + rhs,
y: self.y + rhs,
width: self.width + rhs,
height: self.height + rhs,
}
}
}
impl<T: core::ops::Mul<Output = T> + Copy> core::ops::Mul<T> for BBox<T> {
type Output = BBox<T>;
fn mul(self, rhs: T) -> Self::Output {
BBox {
x: self.x * rhs,
y: self.y * rhs,
width: self.width * rhs,
height: self.height * rhs,
}
}
}
impl<T: core::ops::Div<Output = T> + Copy> core::ops::Div<T> for BBox<T> {
type Output = BBox<T>;
fn div(self, rhs: T) -> Self::Output {
BBox {
x: self.x / rhs,
y: self.y / rhs,
width: self.width / rhs,
height: self.height / rhs,
}
}
}
impl<T> core::ops::Add<BBox<T>> for BBox<T>
where
T: core::ops::Sub<Output = T>
+ core::ops::Add<Output = T>
+ traits::min::Min
+ traits::max::Max
+ Copy,
{
type Output = BBox<T>;
fn add(self, rhs: BBox<T>) -> Self::Output {
let x1 = self.x1().min(rhs.x1());
let y1 = self.y1().min(rhs.y1());
let x2 = self.x2().max(rhs.x2());
let y2 = self.y2().max(rhs.y2());
BBox::new_xyxy(x1, y1, x2, y2)
}
}
#[test]
fn test_bbox_add() {
let bbox1: BBox<usize> = BBox::new_xyxy(0, 0, 10, 10);
let bbox2: BBox<usize> = BBox::new_xyxy(5, 5, 15, 15);
let bbox3: BBox<usize> = bbox1 + bbox2;
assert_eq!(bbox3, BBox::new_xyxy(0, 0, 15, 15).cast());
}
#[derive(
Debug, Copy, Clone, serde::Serialize, serde::Deserialize, PartialEq, PartialOrd, Eq, Ord, Hash,
)]
pub struct Point<T = f32> {
x: T,
y: T,
}
impl<T> Point<T> {
pub const fn new(x: T, y: T) -> Self {
Self { x, y }
}
pub const fn x(&self) -> T
where
T: Copy,
{
self.x
}
pub const fn y(&self) -> T
where
T: Copy,
{
self.y
}
pub fn cast<T2>(&self) -> Point<T2>
where
T: num::cast::AsPrimitive<T2>,
T2: Copy + 'static,
{
Point {
x: self.x.as_(),
y: self.y.as_(),
}
}
}
impl<T: core::ops::Sub<T, Output = T> + Copy> core::ops::Sub<Point<T>> for Point<T> {
type Output = Point<T>;
fn sub(self, rhs: Point<T>) -> Self::Output {
Point {
x: self.x - rhs.x,
y: self.y - rhs.y,
}
}
}
impl<T: core::ops::Add<T, Output = T> + Copy> core::ops::Add<Point<T>> for Point<T> {
type Output = Point<T>;
fn add(self, rhs: Point<T>) -> Self::Output {
Point {
x: self.x + rhs.x,
y: self.y + rhs.y,
}
}
}
impl<T: core::ops::Sub<Output = T> + Copy> Point<T> {
/// If both the boxes are in the same scale then make the translation of the origin to the
/// other box
pub fn with_origin(&self, origin: Self) -> Self {
*self - origin
}
}
impl<T: core::ops::Add<Output = T> + Copy> Point<T> {
pub fn translate(&self, point: Point<T>) -> Self {
*self + point
}
}
impl<I: num::Zero> BBox<I>
where
I: num::cast::AsPrimitive<usize>,
{
pub fn zeros_ndarray_2d<T: num::Zero + Copy>(&self) -> ndarray::Array2<T> {
ndarray::Array2::<T>::zeros((self.height.as_(), self.width.as_()))
}
pub fn zeros_ndarray_3d<T: num::Zero + Copy>(&self, channels: usize) -> ndarray::Array3<T> {
ndarray::Array3::<T>::zeros((self.height.as_(), self.width.as_(), channels))
}
pub fn ones_ndarray_2d<T: num::One + Copy>(&self) -> ndarray::Array2<T> {
ndarray::Array2::<T>::ones((self.height.as_(), self.width.as_()))
}
}
impl<T: num::Float> BBox<T> {
pub fn round(&self) -> Self {
Self {
x: self.x.round(),
y: self.y.round(),
width: self.width.round(),
height: self.height.round(),
}
}
}
#[cfg(test)]
mod bbox_clamp_tests {
use super::*;
#[test]
pub fn bbox_test_clamp_box() {
let large_box = BBox::new(0, 0, 100, 100);
let small_box = BBox::new(10, 10, 20, 20);
let clamped = large_box.clamp_box(small_box);
assert_eq!(clamped, small_box);
}
#[test]
pub fn bbox_test_clamp_box_offset() {
let box_a = BBox::new(0, 0, 100, 100);
let box_b = BBox::new(-10, -10, 20, 20);
let clamped = box_b.clamp_box(box_a);
let expected = BBox::new(0, 0, 10, 10);
assert_eq!(expected, clamped);
}
}
#[cfg(test)]
mod bbox_padding_tests {
use super::*;
#[test]
pub fn bbox_test_padding() {
let bbox = BBox::new(0, 0, 10, 10);
let padded = bbox.padding(2);
assert_eq!(padded, BBox::new(-2, -2, 14, 14));
}
#[test]
pub fn bbox_test_padding_height() {
let bbox = BBox::new(0, 0, 10, 10);
let padded = bbox.padding_height(2);
assert_eq!(padded, BBox::new(0, -2, 10, 14));
}
#[test]
pub fn bbox_test_padding_width() {
let bbox = BBox::new(0, 0, 10, 10);
let padded = bbox.padding_width(2);
assert_eq!(padded, BBox::new(-2, 0, 14, 10));
}
#[test]
pub fn bbox_test_clamped_padding() {
let bbox = BBox::new(0, 0, 10, 10);
let padded = bbox.padding(2);
let clamp = BBox::new(0, 0, 12, 12);
let clamped = padded.clamp_box(clamp);
assert_eq!(clamped, clamp);
}
#[test]
pub fn bbox_clamp_failure() {
let og = BBox::new(475.0, 79.625, 37.0, 282.15);
let padded = BBox {
x: 471.3,
y: 51.412499999999994,
width: 40.69999999999999,
height: 338.54999999999995,
};
let clamp = BBox::new(0.0, 0.0, 512.0, 512.0);
let sus = padded.clamp_box(clamp);
assert!(clamp.contains_bbox(sus));
}
}
#[cfg(test)]
mod bbox_scale_tests {
use super::*;
#[test]
pub fn bbox_test_scale_int() {
let bbox = BBox::new(0, 0, 10, 10);
let scaled = bbox.scale(2);
assert_eq!(scaled, BBox::new(-5, -5, 20, 20));
}
#[test]
pub fn bbox_test_scale_float() {
let bbox = BBox::new(0, 0, 10, 10).cast();
let scaled = bbox.scale(1.05); // 5% increase
let l = 10.0 * 0.05;
assert_eq!(scaled, BBox::new(-l / 2.0, -l / 2.0, 10.0 + l, 10.0 + l));
}
#[test]
pub fn bbox_test_scale_float_negative() {
let bbox = BBox::new(0, 0, 10, 10).cast();
let scaled = bbox.scale(0.95); // 5% decrease
let l = -10.0 * 0.05;
assert_eq!(scaled, BBox::new(-l / 2.0, -l / 2.0, 10.0 + l, 10.0 + l));
}
#[test]
pub fn bbox_scale_float() {
let bbox = BBox::new_xywh(0, 0, 200, 200);
let scaled = bbox.cast::<f64>().scale(1.1).cast::<i32>().clamp(0, 1000);
let expected = BBox::new(0, 0, 220, 220);
assert_eq!(scaled, expected);
}
#[test]
pub fn add_padding_bbox_example() {
// let result = add_padding_bbox(
// vec![Rect::new(100, 200, 300, 400)],
// (0.1, 0.1),
// (1000, 1000),
// );
// assert_eq!(result[0], Rect::new(70, 160, 360, 480));
let bbox = BBox::new(100, 200, 300, 400);
let scaled = bbox.cast::<f64>().scale(1.2).cast::<i32>().clamp(0, 1000);
assert_eq!(bbox, BBox::new(100, 200, 300, 400));
assert_eq!(scaled, BBox::new(70, 160, 360, 480));
}
#[test]
pub fn scale_bboxes() {
// let result = scale_bboxes(Rect::new(100, 200, 300, 400), (1000, 1000), (500, 500));
// assert_eq!(result[0], Rect::new(200, 400, 600, 800));
let bbox = BBox::new(100, 200, 300, 400);
let scaled = bbox.scale(2);
assert_eq!(scaled, BBox::new(200, 400, 600, 800));
}
}

2
bbox/src/traits.rs Normal file
View File

@@ -0,0 +1,2 @@
pub mod max;
pub mod min;

27
bbox/src/traits/max.rs Normal file
View File

@@ -0,0 +1,27 @@
pub trait Max: Sized + Copy {
fn max(self, other: Self) -> Self;
}
macro_rules! impl_max {
($($t:ty),*) => {
$(
impl Max for $t {
fn max(self, other: Self) -> Self {
Ord::max(self, other)
}
}
)*
};
(float $($t:ty),*) => {
$(
impl Max for $t {
fn max(self, other: Self) -> Self {
Self::max(self, other)
}
}
)*
};
}
impl_max!(usize, u8, u16, u32, u64, u128, isize, i8, i16, i32, i64, i128);
impl_max!(float f32, f64);

27
bbox/src/traits/min.rs Normal file
View File

@@ -0,0 +1,27 @@
pub trait Min: Sized + Copy {
fn min(self, other: Self) -> Self;
}
macro_rules! impl_min {
($($t:ty),*) => {
$(
impl Min for $t {
fn min(self, other: Self) -> Self {
Ord::min(self, other)
}
}
)*
};
(float $($t:ty),*) => {
$(
impl Min for $t {
fn min(self, other: Self) -> Self {
Self::min(self, other)
}
}
)*
};
}
impl_min!(usize, u8, u16, u32, u64, u128, isize, i8, i16, i32, i64, i128);
impl_min!(float f32, f64);

View File

@@ -163,6 +163,21 @@ impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
} }
} }
pub fn scale_uniform(self, scalar: T) -> Self
where
T: core::ops::MulAssign,
T: core::ops::DivAssign,
T: core::ops::SubAssign,
{
let two = T::one() + T::one();
let new_size = self.size * scalar;
let new_point = self.point.coords - (new_size - self.size) / two;
Self {
point: Point::from(new_point),
size: new_size,
}
}
pub fn contains_bbox(&self, other: &Self) -> bool pub fn contains_bbox(&self, other: &Self) -> bool
where where
T: core::ops::AddAssign, T: core::ops::AddAssign,
@@ -270,15 +285,17 @@ impl<T: Num, const D: usize> AxisAlignedBoundingBox<T, D> {
}) })
} }
// pub fn as_<T2>(&self) -> Option<Aabb<T2, D>> pub fn as_<T2>(&self) -> Aabb<T2, D>
// where where
// T2: Num + simba::scalar::SubsetOf<T>, T2: Num,
// { T: num::cast::AsPrimitive<T2>,
// Some(Aabb { {
// point: Point::from(self.point.coords.as_()), Aabb {
// size: self.size.as_(), point: Point::from(self.point.coords.map(|x| x.as_())),
// }) size: self.size.map(|x| x.as_()),
// } }
}
pub fn measure(&self) -> T pub fn measure(&self) -> T
where where
T: core::ops::MulAssign, T: core::ops::MulAssign,

View File

@@ -2,9 +2,6 @@
description = "A simple rust flake using rust-overlay and craneLib"; description = "A simple rust flake using rust-overlay and craneLib";
inputs = { inputs = {
self = {
lfs = true;
};
nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable"; nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable";
flake-utils.url = "github:numtide/flake-utils"; flake-utils.url = "github:numtide/flake-utils";
crane.url = "github:ipetkov/crane"; crane.url = "github:ipetkov/crane";
@@ -46,6 +43,8 @@
system: let system: let
pkgs = import nixpkgs { pkgs = import nixpkgs {
inherit system; inherit system;
config.allowUnfree = true;
config.cudaSupport = pkgs.stdenv.isLinux;
overlays = [ overlays = [
rust-overlay.overlays.default rust-overlay.overlays.default
(final: prev: { (final: prev: {
@@ -78,7 +77,7 @@
craneLib = (crane.mkLib pkgs).overrideToolchain stableToolchain; craneLib = (crane.mkLib pkgs).overrideToolchain stableToolchain;
craneLibLLvmTools = (crane.mkLib pkgs).overrideToolchain stableToolchainWithLLvmTools; craneLibLLvmTools = (crane.mkLib pkgs).overrideToolchain stableToolchainWithLLvmTools;
ort_static = pkgs.onnxruntime.overrideAttrs (old: { ort_static = (pkgs.onnxruntime.overide {cudaSupport = true;}).overrideAttrs (old: {
cmakeFlags = cmakeFlags =
old.cmakeFlags old.cmakeFlags
++ [ ++ [
@@ -114,16 +113,17 @@
stdenv = p: p.clangStdenv; stdenv = p: p.clangStdenv;
doCheck = false; doCheck = false;
LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib"; LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib";
ORT_LIB_LOCATION = "${patchedOnnxruntime}"; # ORT_LIB_LOCATION = "${patchedOnnxruntime}";
ORT_ENV_SYSTEM_LIB_LOCATION = "${patchedOnnxruntime}/lib"; # ORT_ENV_SYSTEM_LIB_LOCATION = "${patchedOnnxruntime}/lib";
ORT_ENV_PREFER_DYNAMIC_LINK = true; # ORT_ENV_PREFER_DYNAMIC_LINK = true;
nativeBuildInputs = with pkgs; [ nativeBuildInputs = with pkgs; [
cmake cmake
pkg-config pkg-config
]; ];
buildInputs = with pkgs; buildInputs = with pkgs;
[ [
# onnxruntime patchedOnnxruntime
sqlite
] ]
++ (lib.optionals pkgs.stdenv.isDarwin [ ++ (lib.optionals pkgs.stdenv.isDarwin [
libiconv libiconv
@@ -200,20 +200,55 @@
devShells = { devShells = {
default = pkgs.mkShell.override {stdenv = pkgs.clangStdenv;} ( default = pkgs.mkShell.override {stdenv = pkgs.clangStdenv;} (
commonArgs commonArgs
// { // rec {
LLDB_DEBUGSERVER_PATH = "/Applications/Xcode.app/Contents/SharedFrameworks/LLDB.framework/Versions/A/Resources/debugserver"; LLDB_DEBUGSERVER_PATH = "/Applications/Xcode.app/Contents/SharedFrameworks/LLDB.framework/Versions/A/Resources/debugserver";
LD_LIBRARY_PATH = "$LD_LIBRARY_PATH:${builtins.toString (pkgs.lib.makeLibraryPath packages)}";
packages = with pkgs; packages = with pkgs;
[ [
stableToolchainWithRustAnalyzer stableToolchainWithRustAnalyzer
cargo-expand
cargo-outdated
cargo-nextest cargo-nextest
cargo-deny cargo-deny
cmake cmake
mnn mnn
cargo-make cargo-make
hyperfine hyperfine
opencv
uv
# (python312.withPackages (ps:
# with ps; [
# numpy
# matplotlib
# scikit-learn
# opencv-python
# seaborn
# torch
# torchvision
# tensorflow-lite
# retinaface
# facenet-pytorch
# tqdm
# pillow
# orjson
# huggingface-hub
# # insightface
# ]))
] ]
++ (lib.optionals pkgs.stdenv.isDarwin [ ++ (lib.optionals pkgs.stdenv.isDarwin [
apple-sdk_13 apple-sdk_13
])
++ (lib.optionals pkgs.stdenv.isLinux [
xorg.libX11
xorg.libXcursor
xorg.libXrandr
xorg.libXi
xorg.libxcb
libxkbcommon
vulkan-loader
wayland
zenity
cudatoolkit
]); ]);
} }
); );

View File

@@ -9,5 +9,5 @@ open:
bench: bench:
cargo build --release cargo build --release
BINARY="" hyperfine --warmup 3 --export-markdown benchmark.md \ BINARY="" hyperfine --warmup 3 --export-markdown benchmark.md \
"$CARGO_TARGET_DIR/release/detector detect -f coreml selfie.jpg" \ "$CARGO_TARGET_DIR/release/detector detect -f cpu selfie.jpg" \
"$CARGO_TARGET_DIR/release/detector detect -f coreml -b 16 selfie.jpg" "$CARGO_TARGET_DIR/release/detector detect -f cpu -b 1 selfie.jpg"

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,11 @@
[package]
name = "ndarray-safetensors"
version.workspace = true
edition.workspace = true
[dependencies]
bytemuck = { version = "1.23.2" }
half = { version = "2.6.0", default-features = false, features = ["bytemuck"] }
ndarray = { version = "0.16.1", default-features = false, features = ["std"] }
safetensors = "0.6.2"
thiserror = "2.0.15"

View File

@@ -0,0 +1,449 @@
//! # ndarray-serialize
//!
//! A Rust library for serializing and deserializing `ndarray` arrays using the SafeTensors format.
//!
//! ## Features
//! - Serialize `ndarray::ArrayView` to SafeTensors format
//! - Deserialize SafeTensors data back to `ndarray::ArrayView`
//! - Support for multiple data types (f32, f64, i8-i64, u8-u64, f16, bf16)
//! - Zero-copy deserialization when possible
//! - Metadata support
//!
//! ## Example
//! ```rust
//! use ndarray::Array2;
//! use ndarray_safetensors::{SafeArrays, SafeArrayView};
//!
//! // Create some data
//! let array = Array2::<f32>::zeros((3, 4));
//!
//! // Serialize
//! let mut safe_arrays = SafeArrays::new();
//! safe_arrays.insert_ndarray("my_tensor", array.view()).unwrap();
//! safe_arrays.insert_metadata("author", "example");
//! let bytes = safe_arrays.serialize().unwrap();
//!
//! // Deserialize
//! let view = SafeArrayView::from_bytes(&bytes).unwrap();
//! let tensor: ndarray::ArrayView2<f32> = view.tensor("my_tensor").unwrap();
//! assert_eq!(tensor.shape(), &[3, 4]);
//! ```
use safetensors::View;
use std::borrow::Cow;
use std::collections::{BTreeMap, HashMap};
use thiserror::Error;
/// Errors that can occur during SafeTensor operations
#[derive(Error, Debug)]
pub enum SafeTensorError {
#[error("Tensor not found: {0}")]
TensorNotFound(String),
#[error("Invalid tensor data: Got {0} Expected: {1}")]
InvalidTensorData(&'static str, String),
#[error("IO error: {0}")]
IoError(#[from] std::io::Error),
#[error("Safetensor error: {0}")]
SafeTensor(#[from] safetensors::SafeTensorError),
#[error("ndarray::ShapeError error: {0}")]
NdarrayShapeError(#[from] ndarray::ShapeError),
}
type Result<T, E = SafeTensorError> = core::result::Result<T, E>;
use safetensors::tensor::SafeTensors;
/// A view into SafeTensors data that provides access to ndarray tensors
///
/// # Example
/// ```rust
/// use ndarray::Array2;
/// use ndarray_safetensors::{SafeArrays, SafeArrayView};
///
/// let array = Array2::<f32>::ones((2, 3));
/// let mut safe_arrays = SafeArrays::new();
/// safe_arrays.insert_ndarray("data", array.view()).unwrap();
/// let bytes = safe_arrays.serialize().unwrap();
///
/// let view = SafeArrayView::from_bytes(&bytes).unwrap();
/// let tensor: ndarray::ArrayView2<f32> = view.tensor("data").unwrap();
/// ```
#[derive(Debug)]
pub struct SafeArraysView<'a> {
pub tensors: SafeTensors<'a>,
}
impl<'a> SafeArraysView<'a> {
fn new(tensors: SafeTensors<'a>) -> Self {
Self { tensors }
}
/// Create a SafeArrayView from serialized bytes
pub fn from_bytes(bytes: &'a [u8]) -> Result<SafeArraysView<'a>> {
let tensors = SafeTensors::deserialize(bytes)?;
Ok(Self::new(tensors))
}
/// Get a dynamic-dimensional tensor by name
pub fn dynamic_tensor<T: STDtype>(&self, name: &str) -> Result<ndarray::ArrayViewD<'a, T>> {
self.tensors
.tensor(name)
.map(|tensor| tensor_view_to_array_view(tensor))?
}
/// Get a tensor with specific dimensions by name
///
/// # Example
/// ```rust
/// # use ndarray::Array2;
/// # use ndarray_safetensors::{SafeArrays, SafeArrayView};
/// # let array = Array2::<f32>::ones((2, 3));
/// # let mut safe_arrays = SafeArrays::new();
/// # safe_arrays.insert_ndarray("data", array.view()).unwrap();
/// # let bytes = safe_arrays.serialize().unwrap();
/// # let view = SafeArrayView::from_bytes(&bytes).unwrap();
/// let tensor: ndarray::ArrayView2<f32> = view.tensor("data").unwrap();
/// ```
pub fn tensor<T: STDtype, Dim: ndarray::Dimension>(
&self,
name: &str,
) -> Result<ndarray::ArrayView<'a, T, Dim>> {
Ok(self
.tensors
.tensor(name)
.map(|tensor| tensor_view_to_array_view(tensor))?
.map(|array_view| array_view.into_dimensionality::<Dim>())??)
}
pub fn tensor_by_index<T: STDtype, Dim: ndarray::Dimension>(
&self,
index: usize,
) -> Result<ndarray::ArrayView<'a, T, Dim>> {
self.tensors
.iter()
.nth(index)
.ok_or(SafeTensorError::TensorNotFound(format!(
"Index {} out of bounds",
index
)))
.map(|(_, tensor)| tensor_view_to_array_view(tensor))?
.map(|array_view| array_view.into_dimensionality::<Dim>())?
.map_err(SafeTensorError::NdarrayShapeError)
}
/// Get an iterator over tensor names
pub fn names(&self) -> std::vec::IntoIter<&str> {
self.tensors.names().into_iter()
}
/// Get the number of tensors
pub fn len(&self) -> usize {
self.tensors.len()
}
/// Check if there are no tensors
pub fn is_empty(&self) -> bool {
self.tensors.is_empty()
}
}
/// Trait for types that can be stored in SafeTensors
///
/// Implemented for: f32, f64, i8, i16, i32, i64, u8, u16, u32, u64, f16, bf16
pub trait STDtype: bytemuck::Pod {
fn dtype() -> safetensors::tensor::Dtype;
fn size() -> usize {
(Self::dtype().bitsize() / 8).max(1)
}
}
macro_rules! impl_dtype {
($($t:ty => $dtype:expr),* $(,)?) => {
$(
impl STDtype for $t {
fn dtype() -> safetensors::tensor::Dtype {
$dtype
}
}
)*
};
}
use safetensors::tensor::Dtype;
impl_dtype!(
// bool => Dtype::BOOL, // idk if ndarray::ArrayD<bool> is packed
f32 => Dtype::F32,
f64 => Dtype::F64,
i8 => Dtype::I8,
i16 => Dtype::I16,
i32 => Dtype::I32,
i64 => Dtype::I64,
u8 => Dtype::U8,
u16 => Dtype::U16,
u32 => Dtype::U32,
u64 => Dtype::U64,
half::f16 => Dtype::F16,
half::bf16 => Dtype::BF16,
);
fn tensor_view_to_array_view<'a, T: STDtype>(
tensor: safetensors::tensor::TensorView<'a>,
) -> Result<ndarray::ArrayViewD<'a, T>> {
let shape = tensor.shape();
let dtype = tensor.dtype();
if T::dtype() != dtype {
return Err(SafeTensorError::InvalidTensorData(
core::any::type_name::<T>(),
dtype.to_string(),
));
}
let data = tensor.data();
let data: &[T] = bytemuck::cast_slice(data);
let array = ndarray::ArrayViewD::from_shape(shape, data)?;
Ok(array)
}
/// Builder for creating SafeTensors data from ndarray tensors
///
/// # Example
/// ```rust
/// use ndarray::{Array1, Array2};
/// use ndarray_safetensors::SafeArrays;
///
/// let mut safe_arrays = SafeArrays::new();
///
/// let array1 = Array1::<f32>::from_vec(vec![1.0, 2.0, 3.0]);
/// let array2 = Array2::<i32>::zeros((2, 2));
///
/// safe_arrays.insert_ndarray("vector", array1.view()).unwrap();
/// safe_arrays.insert_ndarray("matrix", array2.view()).unwrap();
/// safe_arrays.insert_metadata("version", "1.0");
///
/// let bytes = safe_arrays.serialize().unwrap();
/// ```
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct SafeArrays<'a> {
pub tensors: BTreeMap<String, SafeArray<'a>>,
pub metadata: Option<HashMap<String, String>>,
}
impl<'a, K: AsRef<str>> FromIterator<(K, SafeArray<'a>)> for SafeArrays<'a> {
fn from_iter<T: IntoIterator<Item = (K, SafeArray<'a>)>>(iter: T) -> Self {
let tensors = iter
.into_iter()
.map(|(k, v)| (k.as_ref().to_owned(), v))
.collect();
Self {
tensors,
metadata: None,
}
}
}
impl<'a, K: AsRef<str>, T: IntoIterator<Item = (K, SafeArray<'a>)>> From<T> for SafeArrays<'a> {
fn from(iter: T) -> Self {
let tensors = iter
.into_iter()
.map(|(k, v)| (k.as_ref().to_owned(), v))
.collect();
Self {
tensors,
metadata: None,
}
}
}
impl<'a> SafeArrays<'a> {
/// Create a SafeArrays from an iterator of (name, ndarray::ArrayView) pairs
/// ```rust
/// use ndarray::{Array2, Array3};
/// use ndarray_safetensors::{SafeArrays, SafeArray};
/// let array = Array2::<f32>::zeros((3, 4));
/// let safe_arrays = SafeArrays::from_ndarrays(vec![
/// ("test_tensor", array.view()),
/// ("test_tensor2", array.view()),
/// ]).unwrap();
/// ```
pub fn from_ndarrays<
K: AsRef<str>,
T: STDtype,
D: ndarray::Dimension + 'a,
I: IntoIterator<Item = (K, ndarray::ArrayView<'a, T, D>)>,
>(
iter: I,
) -> Result<Self> {
let tensors = iter
.into_iter()
.map(|(k, v)| Ok((k.as_ref().to_owned(), SafeArray::from_ndarray(v)?)))
.collect::<Result<BTreeMap<String, SafeArray<'a>>>>()?;
Ok(Self {
tensors,
metadata: None,
})
}
}
// impl<'a, K: AsRef<str>, T: IntoIterator<Item = (K, SafeArray<'a>)>> From<T> for SafeArrays<'a> {
// fn from(iter: T) -> Self {
// let tensors = iter
// .into_iter()
// .map(|(k, v)| (k.as_ref().to_owned(), v))
// .collect();
// Self {
// tensors,
// metadata: None,
// }
// }
// }
impl<'a> SafeArrays<'a> {
/// Create a new empty SafeArrays builder
pub const fn new() -> Self {
Self {
tensors: BTreeMap::new(),
metadata: None,
}
}
/// Insert a SafeArray tensor with the given name
pub fn insert_tensor<'b: 'a>(&mut self, name: impl AsRef<str>, tensor: SafeArray<'b>) {
self.tensors.insert(name.as_ref().to_owned(), tensor);
}
/// Insert an ndarray tensor with the given name
///
/// The array must be in standard layout and contiguous.
pub fn insert_ndarray<'b: 'a, T: STDtype, D: ndarray::Dimension + 'a>(
&mut self,
name: impl AsRef<str>,
array: ndarray::ArrayView<'b, T, D>,
) -> Result<()> {
self.insert_tensor(name, SafeArray::from_ndarray(array)?);
Ok(())
}
/// Insert metadata key-value pair
pub fn insert_metadata(&mut self, key: impl AsRef<str>, value: impl AsRef<str>) {
self.metadata
.get_or_insert_default()
.insert(key.as_ref().to_owned(), value.as_ref().to_owned());
}
/// Serialize all tensors and metadata to bytes
pub fn serialize(self) -> Result<Vec<u8>> {
let out = safetensors::serialize(self.tensors, self.metadata)
.map_err(SafeTensorError::SafeTensor)?;
Ok(out)
}
}
/// A tensor that can be serialized to SafeTensors format
#[derive(Debug, Clone)]
pub struct SafeArray<'a> {
data: Cow<'a, [u8]>,
shape: Vec<usize>,
dtype: safetensors::tensor::Dtype,
}
impl View for SafeArray<'_> {
fn dtype(&self) -> safetensors::tensor::Dtype {
self.dtype
}
fn shape(&self) -> &[usize] {
&self.shape
}
fn data(&self) -> Cow<'_, [u8]> {
self.data.clone()
}
fn data_len(&self) -> usize {
self.data.len()
}
}
impl<'a> SafeArray<'a> {
fn from_ndarray<'b: 'a, T: STDtype, D: ndarray::Dimension + 'a>(
array: ndarray::ArrayView<'b, T, D>,
) -> Result<Self> {
let shape = array.shape().to_vec();
let dtype = T::dtype();
if array.ndim() == 0 {
return Err(SafeTensorError::InvalidTensorData(
core::any::type_name::<T>(),
"Cannot insert a scalar tensor".to_string(),
));
}
if !array.is_standard_layout() {
return Err(SafeTensorError::InvalidTensorData(
core::any::type_name::<T>(),
"ArrayView is not standard layout".to_string(),
));
}
let data =
bytemuck::cast_slice(array.to_slice().ok_or(SafeTensorError::InvalidTensorData(
core::any::type_name::<T>(),
"ArrayView is not contiguous".to_string(),
))?);
let safe_array = SafeArray {
data: Cow::Borrowed(data),
shape,
dtype,
};
Ok(safe_array)
}
}
#[test]
fn test_safe_array_from_ndarray() {
use ndarray::Array2;
let array = Array2::<f32>::zeros((3, 4));
let safe_array = SafeArray::from_ndarray(array.view()).unwrap();
assert_eq!(safe_array.shape, vec![3, 4]);
assert_eq!(safe_array.dtype, safetensors::tensor::Dtype::F32);
assert_eq!(safe_array.data.len(), 3 * 4 * 4); // 3x4x4 bytes for f32
}
#[test]
fn test_serialize_safe_arrays() {
use ndarray::{Array2, Array3};
let mut safe_arrays = SafeArrays::new();
let array = Array2::<f32>::zeros((3, 4));
let array2 = Array3::<u16>::zeros((8, 1, 9));
safe_arrays
.insert_ndarray("test_tensor", array.view())
.unwrap();
safe_arrays
.insert_ndarray("test_tensor2", array2.view())
.unwrap();
safe_arrays.insert_metadata("author", "example");
let serialized = safe_arrays.serialize().unwrap();
assert!(!serialized.is_empty());
// Deserialize to check if it works
let deserialized = SafeArraysView::from_bytes(&serialized).unwrap();
assert_eq!(deserialized.len(), 2);
assert_eq!(
deserialized
.tensor::<f32, ndarray::Ix2>("test_tensor")
.unwrap()
.shape(),
&[3, 4]
);
assert_eq!(
deserialized
.tensor::<u16, ndarray::Ix3>("test_tensor2")
.unwrap()
.shape(),
&[8, 1, 9]
);
}

36
ndcv-bridge/Cargo.toml Normal file
View File

@@ -0,0 +1,36 @@
[package]
name = "ndcv-bridge"
version.workspace = true
edition.workspace = true
[dependencies]
bounding-box.workspace = true
nalgebra.workspace = true
bytemuck.workspace = true
error-stack.workspace = true
fast_image_resize.workspace = true
ndarray = { workspace = true, features = ["rayon"] }
num.workspace = true
opencv = { workspace = true, optional = true }
rayon = "1.10.0"
thiserror.workspace = true
tracing = "0.1.41"
wide = "0.7.32"
img-parts.workspace = true
[dev-dependencies]
divan.workspace = true
ndarray-npy.workspace = true
[features]
opencv = ["dep:opencv"]
default = ["opencv"]
[[bench]]
name = "conversions"
harness = false
[[bench]]
name = "gaussian"
harness = false

View File

@@ -0,0 +1,75 @@
use divan::black_box;
use ndcv_bridge::*;
// #[global_allocator]
// static ALLOC: AllocProfiler = AllocProfiler::system();
fn main() {
divan::main();
}
#[divan::bench]
fn bench_3d_mat_to_ndarray_512() {
bench_mat_to_3d_ndarray(512);
}
#[divan::bench]
fn bench_3d_mat_to_ndarray_1024() {
bench_mat_to_3d_ndarray(1024);
}
#[divan::bench]
fn bench_3d_mat_to_ndarray_2k() {
bench_mat_to_3d_ndarray(2048);
}
#[divan::bench]
fn bench_3d_mat_to_ndarray_4k() {
bench_mat_to_3d_ndarray(4096);
}
#[divan::bench]
fn bench_3d_mat_to_ndarray_8k() {
bench_mat_to_3d_ndarray(8192);
}
#[divan::bench]
fn bench_3d_mat_to_ndarray_8k_ref() {
bench_mat_to_3d_ndarray_ref(8192);
}
#[divan::bench]
fn bench_2d_mat_to_ndarray_8k_ref() {
bench_mat_to_2d_ndarray(8192);
}
fn bench_mat_to_2d_ndarray(size: i32) -> ndarray::Array2<u8> {
let mat =
opencv::core::Mat::new_nd_with_default(&[size, size], opencv::core::CV_8UC1, (200).into())
.expect("failed");
let ndarray: ndarray::Array2<u8> = mat.as_ndarray().expect("failed").to_owned();
ndarray
}
fn bench_mat_to_3d_ndarray(size: i32) -> ndarray::Array3<u8> {
let mat = opencv::core::Mat::new_nd_with_default(
&[size, size],
opencv::core::CV_8UC3,
(200, 100, 10).into(),
)
.expect("failed");
// ndarray::Array3::<u8>::from_mat(black_box(mat)).expect("failed")
let ndarray: ndarray::Array3<u8> = mat.as_ndarray().expect("failed").to_owned();
ndarray
}
fn bench_mat_to_3d_ndarray_ref(size: i32) {
let mut mat = opencv::core::Mat::new_nd_with_default(
&[size, size],
opencv::core::CV_8UC3,
(200, 100, 10).into(),
)
.expect("failed");
let array: ndarray::ArrayView3<u8> = black_box(&mut mat).as_ndarray().expect("failed");
let _ = black_box(array);
}

View File

@@ -0,0 +1,265 @@
use divan::black_box;
use ndarray::*;
use ndcv_bridge::*;
// #[global_allocator]
// static ALLOC: AllocProfiler = AllocProfiler::system();
fn main() {
divan::main();
}
// Helper function to create test images with different patterns
fn create_test_image(size: usize, pattern: &str) -> Array3<u8> {
let mut arr = Array3::<u8>::zeros((size, size, 3));
match pattern {
"edges" => {
// Create a pattern with sharp edges
arr.slice_mut(s![size / 4..3 * size / 4, size / 4..3 * size / 4, ..])
.fill(255);
}
"gradient" => {
// Create a gradual gradient
for i in 0..size {
let val = (i * 255 / size) as u8;
arr.slice_mut(s![i, .., ..]).fill(val);
}
}
"checkerboard" => {
// Create a checkerboard pattern
for i in 0..size {
for j in 0..size {
if (i / 20 + j / 20) % 2 == 0 {
arr[[i, j, 0]] = 255;
arr[[i, j, 1]] = 255;
arr[[i, j, 2]] = 255;
}
}
}
}
_ => arr.fill(255), // Default to solid white
}
arr
}
#[divan::bench_group]
mod sizes {
use super::*;
// Benchmark different image sizes
#[divan::bench(args = [512, 1024, 2048, 4096])]
fn bench_gaussian_sizes_u8(size: usize) {
let arr = Array3::<u8>::ones((size, size, 3));
let _out = black_box(
arr.gaussian_blur((3, 3), 1.0, 1.0, BorderType::BorderConstant)
.unwrap(),
);
}
#[divan::bench(args = [512, 1024, 2048, 4096])]
fn bench_gaussian_sizes_u8_inplace(size: usize) {
let mut arr = Array3::<u8>::ones((size, size, 3));
black_box(
arr.gaussian_blur_inplace((3, 3), 1.0, 1.0, BorderType::BorderConstant)
.unwrap(),
);
}
#[divan::bench(args = [512, 1024, 2048, 4096])]
fn bench_gaussian_sizes_f32(size: usize) {
let arr = Array3::<f32>::ones((size, size, 3));
let _out = black_box(
arr.gaussian_blur((3, 3), 1.0, 1.0, BorderType::BorderConstant)
.unwrap(),
);
}
#[divan::bench(args = [512, 1024, 2048, 4096])]
fn bench_gaussian_sizes_f32_inplace(size: usize) {
let mut arr = Array3::<f32>::ones((size, size, 3));
black_box(
arr.gaussian_blur_inplace((3, 3), 1.0, 1.0, BorderType::BorderConstant)
.unwrap(),
);
}
}
// Benchmark different kernel sizes
#[divan::bench(args = [(3, 3), (5, 5), (7, 7), (9, 9), (11, 11)])]
fn bench_gaussian_kernels(kernel_size: (u8, u8)) {
let mut arr = Array3::<u8>::ones((1000, 1000, 3));
arr.gaussian_blur_inplace(kernel_size, 1.0, 1.0, BorderType::BorderConstant)
.unwrap();
}
// Benchmark different sigma values
#[divan::bench(args = [0.5, 1.0, 2.0, 5.0])]
fn bench_gaussian_sigmas(sigma: f64) {
let mut arr = Array3::<u8>::ones((1000, 1000, 3));
arr.gaussian_blur_inplace((3, 3), sigma, sigma, BorderType::BorderConstant)
.unwrap();
}
// Benchmark different sigma_x and sigma_y combinations
#[divan::bench(args = [(0.5, 2.0), (1.0, 1.0), (2.0, 0.5), (3.0, 1.0)])]
fn bench_gaussian_asymmetric_sigmas(sigmas: (f64, f64)) {
let mut arr = Array3::<u8>::ones((1000, 1000, 3));
arr.gaussian_blur_inplace((3, 3), sigmas.0, sigmas.1, BorderType::BorderConstant)
.unwrap();
}
// Benchmark different border types
#[divan::bench]
fn bench_gaussian_border_types() -> Vec<()> {
let border_types = [
BorderType::BorderConstant,
BorderType::BorderReplicate,
BorderType::BorderReflect,
BorderType::BorderReflect101,
];
let mut arr = Array3::<u8>::ones((1000, 1000, 3));
border_types
.iter()
.map(|border_type| {
arr.gaussian_blur_inplace((3, 3), 1.0, 1.0, *border_type)
.unwrap();
})
.collect()
}
// Benchmark different image patterns
#[divan::bench]
fn bench_gaussian_patterns() {
let patterns = ["edges", "gradient", "checkerboard", "solid"];
patterns.iter().for_each(|&pattern| {
let mut arr = create_test_image(1000, pattern);
arr.gaussian_blur_inplace((3, 3), 1.0, 1.0, BorderType::BorderConstant)
.unwrap();
})
}
#[divan::bench_group]
mod allocation {
use super::*;
#[divan::bench]
fn bench_gaussian_allocation_inplace() {
let mut arr = Array3::<f32>::ones((3840, 2160, 3));
black_box(
arr.gaussian_blur_inplace((3, 3), 1.0, 1.0, BorderType::BorderConstant)
.unwrap(),
);
}
#[divan::bench]
fn bench_gaussian_allocation_allocate() {
let arr = Array3::<f32>::ones((3840, 2160, 3));
let _out = black_box(
arr.gaussian_blur((3, 3), 1.0, 1.0, BorderType::BorderConstant)
.unwrap(),
);
}
}
#[divan::bench_group]
mod realistic {
use super::*;
#[divan::bench]
fn small_800_600_3x3() {
let small_blur = Array3::<u8>::ones((800, 600, 3));
let _blurred = black_box(
small_blur
.gaussian_blur((3, 3), 0.5, 0.5, BorderType::BorderConstant)
.unwrap(),
);
}
#[divan::bench]
fn small_800_600_3x3_inplace() {
let mut small_blur = Array3::<u8>::ones((800, 600, 3));
small_blur
.gaussian_blur_inplace((3, 3), 0.5, 0.5, BorderType::BorderConstant)
.unwrap();
}
#[divan::bench]
fn medium_1920x1080_5x5() {
let mut medium_blur = Array3::<u8>::ones((1920, 1080, 3));
let _blurred = black_box(
medium_blur
.gaussian_blur_inplace((5, 5), 2.0, 2.0, BorderType::BorderConstant)
.unwrap(),
);
}
#[divan::bench]
fn medium_1920x1080_5x5_inplace() {
let mut medium_blur = Array3::<u8>::ones((1920, 1080, 3));
medium_blur
.gaussian_blur_inplace((5, 5), 2.0, 2.0, BorderType::BorderConstant)
.unwrap();
}
#[divan::bench]
fn large_3840x2160_9x9() {
let large_blur = Array3::<u8>::ones((3840, 2160, 3));
let _blurred = black_box(
large_blur
.gaussian_blur((9, 9), 5.0, 5.0, BorderType::BorderConstant)
.unwrap(),
);
}
#[divan::bench]
fn large_3840x2160_9x9_inplace() {
let mut large_blur = Array3::<u8>::ones((3840, 2160, 3));
large_blur
.gaussian_blur_inplace((9, 9), 5.0, 5.0, BorderType::BorderConstant)
.unwrap();
}
#[divan::bench]
fn small_800_600_3x3_f32() {
let small_blur = Array3::<f32>::ones((800, 600, 3));
let _blurred = black_box(
small_blur
.gaussian_blur((3, 3), 0.5, 0.5, BorderType::BorderConstant)
.unwrap(),
);
}
#[divan::bench]
fn small_800_600_3x3_inplace_f32() {
let mut small_blur = Array3::<f32>::ones((800, 600, 3));
small_blur
.gaussian_blur_inplace((3, 3), 0.5, 0.5, BorderType::BorderConstant)
.unwrap();
}
#[divan::bench]
fn medium_1920x1080_5x5_f32() {
let mut medium_blur = Array3::<f32>::ones((1920, 1080, 3));
let _blurred = black_box(
medium_blur
.gaussian_blur_inplace((5, 5), 2.0, 2.0, BorderType::BorderConstant)
.unwrap(),
);
}
#[divan::bench]
fn medium_1920x1080_5x5_inplace_f32() {
let mut medium_blur = Array3::<f32>::ones((1920, 1080, 3));
medium_blur
.gaussian_blur_inplace((5, 5), 2.0, 2.0, BorderType::BorderConstant)
.unwrap();
}
#[divan::bench]
fn large_3840x2160_9x9_f32() {
let large_blur = Array3::<f32>::ones((3840, 2160, 3));
let _blurred = black_box(
large_blur
.gaussian_blur((9, 9), 5.0, 5.0, BorderType::BorderConstant)
.unwrap(),
);
}
#[divan::bench]
fn large_3840x2160_9x9_inplace_f32() {
let mut large_blur = Array3::<f32>::ones((3840, 2160, 3));
large_blur
.gaussian_blur_inplace((9, 9), 5.0, 5.0, BorderType::BorderConstant)
.unwrap();
}
}

180
ndcv-bridge/src/blend.rs Normal file
View File

@@ -0,0 +1,180 @@
use crate::prelude_::*;
use ndarray::*;
type Result<T, E = Report<NdCvError>> = std::result::Result<T, E>;
mod seal {
pub trait Sealed {}
impl<T: ndarray::Data<Elem = f32>> Sealed for ndarray::ArrayBase<T, ndarray::Ix3> {}
}
pub trait NdBlend<T, D: ndarray::Dimension>: seal::Sealed {
fn blend(
&self,
mask: ndarray::ArrayView<T, D::Smaller>,
other: ndarray::ArrayView<T, D>,
alpha: T,
) -> Result<ndarray::Array<T, D>>;
fn blend_inplace(
&mut self,
mask: ndarray::ArrayView<T, D::Smaller>,
other: ndarray::ArrayView<T, D>,
alpha: T,
) -> Result<()>;
}
impl<S> NdBlend<f32, Ix3> for ndarray::ArrayBase<S, Ix3>
where
S: ndarray::DataMut<Elem = f32>,
{
fn blend(
&self,
mask: ndarray::ArrayView<f32, Ix2>,
other: ndarray::ArrayView<f32, Ix3>,
alpha: f32,
) -> Result<ndarray::Array<f32, Ix3>> {
if self.shape() != other.shape() {
return Err(NdCvError)
.attach_printable("Shapes of image and other imagge do not match");
}
if self.shape()[0] != mask.shape()[0] || self.shape()[1] != mask.shape()[1] {
return Err(NdCvError).attach_printable("Shapes of image and mask do not match");
}
let mut output = ndarray::Array3::zeros(self.dim());
let (_height, _width, channels) = self.dim();
Zip::from(output.lanes_mut(Axis(2)))
.and(self.lanes(Axis(2)))
.and(other.lanes(Axis(2)))
.and(mask)
.par_for_each(|mut out, this, other, mask| {
let this = wide::f32x4::from(this.as_slice().expect("Invalid self array"));
let other = wide::f32x4::from(other.as_slice().expect("Invalid other array"));
let mask = wide::f32x4::splat(mask * alpha);
let o = this * (1.0 - mask) + other * mask;
out.as_slice_mut()
.expect("Failed to get mutable slice")
.copy_from_slice(&o.as_array_ref()[..channels]);
});
Ok(output)
}
fn blend_inplace(
&mut self,
mask: ndarray::ArrayView<f32, <Ix3 as Dimension>::Smaller>,
other: ndarray::ArrayView<f32, Ix3>,
alpha: f32,
) -> Result<()> {
if self.shape() != other.shape() {
return Err(NdCvError)
.attach_printable("Shapes of image and other imagge do not match");
}
if self.shape()[0] != mask.shape()[0] || self.shape()[1] != mask.shape()[1] {
return Err(NdCvError).attach_printable("Shapes of image and mask do not match");
}
let (_height, _width, channels) = self.dim();
// Zip::from(self.lanes_mut(Axis(2)))
// .and(other.lanes(Axis(2)))
// .and(mask)
// .par_for_each(|mut this, other, mask| {
// let this_wide = wide::f32x4::from(this.as_slice().expect("Invalid self array"));
// let other = wide::f32x4::from(other.as_slice().expect("Invalid other array"));
// let mask = wide::f32x4::splat(mask * alpha);
// let o = this_wide * (1.0 - mask) + other * mask;
// this.as_slice_mut()
// .expect("Failed to get mutable slice")
// .copy_from_slice(&o.as_array_ref()[..channels]);
// });
let this = self
.as_slice_mut()
.ok_or(NdCvError)
.attach_printable("Failed to get source image as a continuous slice")?;
let other = other
.as_slice()
.ok_or(NdCvError)
.attach_printable("Failed to get other image as a continuous slice")?;
let mask = mask
.as_slice()
.ok_or(NdCvError)
.attach_printable("Failed to get mask as a continuous slice")?;
use rayon::prelude::*;
this.par_chunks_exact_mut(channels)
.zip(other.par_chunks_exact(channels))
.zip(mask)
.for_each(|((this, other), mask)| {
let this_wide = wide::f32x4::from(&*this);
let other = wide::f32x4::from(other);
let mask = wide::f32x4::splat(mask * alpha);
this.copy_from_slice(
&(this_wide * (1.0 - mask) + other * mask).as_array_ref()[..channels],
);
});
// for h in 0.._height {
// for w in 0.._width {
// let mask_index = h * _width + w;
// let mask = mask[mask_index];
// let mask = wide::f32x4::splat(mask * alpha);
// let this = &mut this[mask_index * channels..(mask_index + 1) * channels];
// let other = &other[mask_index * channels..(mask_index + 1) * channels];
// let this_wide = wide::f32x4::from(&*this);
// let other = wide::f32x4::from(other);
// let o = this_wide * (1.0 - mask) + other * mask;
// this.copy_from_slice(&o.as_array_ref()[..channels]);
// }
// }
Ok(())
}
}
#[test]
pub fn test_blend() {
let img = Array3::<f32>::from_shape_fn((10, 10, 3), |(i, j, k)| match (i, j, k) {
(0..=3, _, 0) => 1f32, // red
(4..=6, _, 1) => 1f32, // green
(7..=9, _, 2) => 1f32, // blue
_ => 0f32,
});
let other = img.clone().permuted_axes([1, 0, 2]).to_owned();
let mask = Array2::<f32>::from_shape_fn((10, 10), |(_, j)| if j > 5 { 1f32 } else { 0f32 });
// let other = Array3::<f32>::zeros((10, 10, 3));
let out = img.blend(mask.view(), other.view(), 1f32).unwrap();
let out_u8 = out.mapv(|v| (v * 255f32) as u8);
let expected = Array3::<u8>::from_shape_fn((10, 10, 3), |(i, j, k)| {
match (i, j, k) {
(0..=3, 0..=5, 0) => u8::MAX, // red
(4..=6, 0..=5, 1) | (_, 6, 1) => u8::MAX, // green
(7..=9, 0..=5, 2) | (_, 7..=10, 2) => u8::MAX, // blue
_ => u8::MIN,
}
});
assert_eq!(out_u8, expected);
}
// #[test]
// pub fn test_blend_inplace() {
// let mut img = Array3::<f32>::from_shape_fn((10, 10, 3), |(i, j, k)| match (i, j, k) {
// (0..=3, _, 0) => 1f32, // red
// (4..=6, _, 1) => 1f32, // green
// (7..=9, _, 2) => 1f32, // blue
// _ => 0f32,
// });
// let other = img.clone().permuted_axes([1, 0, 2]);
// let mask = Array2::<f32>::from_shape_fn((10, 10), |(_, j)| if j > 5 { 1f32 } else { 0f32 });
// // let other = Array3::<f32>::zeros((10, 10, 3));
// img.blend_inplace(mask.view(), other.view(), 1f32).unwrap();
// let out_u8 = img.mapv(|v| (v * 255f32) as u8);
// let expected = Array3::<u8>::from_shape_fn((10, 10, 3), |(i, j, k)| {
// match (i, j, k) {
// (0..=3, 0..=5, 0) => u8::MAX, // red
// (4..=6, 0..=5, 1) | (_, 6, 1) => u8::MAX, // green
// (7..=9, 0..=5, 2) | (_, 7..=10, 2) => u8::MAX, // blue
// _ => u8::MIN,
// }
// });
// assert_eq!(out_u8, expected);
// }

View File

@@ -0,0 +1,48 @@
//! Calculates the up-right bounding rectangle of a point set or non-zero pixels of gray-scale image.
//! The function calculates and returns the minimal up-right bounding rectangle for the specified point set or non-zero pixels of gray-scale image.
use crate::{NdAsImage, prelude_::*};
pub trait BoundingRect: seal::SealedInternal {
fn bounding_rect(&self) -> Result<bounding_box::Aabb2<i32>, NdCvError>;
}
mod seal {
pub trait SealedInternal {}
impl<T, S: ndarray::Data<Elem = T>> SealedInternal for ndarray::ArrayBase<S, ndarray::Ix2> {}
}
impl<S: ndarray::Data<Elem = u8>> BoundingRect for ndarray::ArrayBase<S, ndarray::Ix2> {
fn bounding_rect(&self) -> Result<bounding_box::Aabb2<i32>, NdCvError> {
let mat = self.as_image_mat()?;
let rect = opencv::imgproc::bounding_rect(mat.as_ref()).change_context(NdCvError)?;
Ok(bounding_box::Aabb2::from_xywh(
rect.x,
rect.y,
rect.width,
rect.height,
))
}
}
#[test]
fn test_bounding_rect_empty() {
let arr = ndarray::Array2::<u8>::zeros((10, 10));
let rect = arr.bounding_rect().unwrap();
assert_eq!(rect, bounding_box::Aabb2::from_xywh(0, 0, 0, 0));
}
#[test]
fn test_bounding_rect_valued() {
let mut arr = ndarray::Array2::<u8>::zeros((10, 10));
crate::NdRoiMut::roi_mut(&mut arr, bounding_box::Aabb2::from_xywh(1, 1, 3, 3)).fill(1);
let rect = arr.bounding_rect().unwrap();
assert_eq!(rect, bounding_box::Aabb2::from_xywh(1, 1, 3, 3));
}
#[test]
fn test_bounding_rect_complex() {
let mut arr = ndarray::Array2::<u8>::zeros((10, 10));
crate::NdRoiMut::roi_mut(&mut arr, bounding_box::Aabb2::from_xywh(1, 3, 3, 3)).fill(1);
crate::NdRoiMut::roi_mut(&mut arr, bounding_box::Aabb2::from_xywh(2, 3, 3, 5)).fill(5);
let rect = arr.bounding_rect().unwrap();
assert_eq!(rect, bounding_box::Aabb2::from_xywh(1, 3, 4, 5));
}

4
ndcv-bridge/src/codec.rs Normal file
View File

@@ -0,0 +1,4 @@
pub mod codecs;
pub mod decode;
pub mod encode;
pub mod error;

View File

@@ -0,0 +1,218 @@
use super::decode::Decoder;
use super::encode::Encoder;
use crate::NdCvError;
use crate::conversions::matref::MatRef;
use error_stack::*;
use img_parts::{
Bytes,
jpeg::{Jpeg, markers},
};
use opencv::{
core::{Mat, Vector, VectorToVec},
imgcodecs::{ImreadModes, ImwriteFlags, imdecode, imencode},
};
#[derive(Debug)]
pub enum CvEncoder {
Jpeg(CvJpegEncFlags),
Tiff(CvTiffEncFlags),
}
pub enum EncKind {
Jpeg,
Tiff,
}
impl CvEncoder {
fn kind(&self) -> EncKind {
match self {
Self::Jpeg(_) => EncKind::Jpeg,
Self::Tiff(_) => EncKind::Tiff,
}
}
fn extension(&self) -> &'static str {
match self {
Self::Jpeg(_) => ".jpg",
Self::Tiff(_) => ".tiff",
}
}
fn to_cv_param_list(&self) -> Vector<i32> {
match self {
Self::Jpeg(flags) => flags.to_cv_param_list(),
Self::Tiff(flags) => flags.to_cv_param_list(),
}
}
}
#[derive(Default, Debug)]
pub struct CvJpegEncFlags {
quality: Option<usize>,
progressive: Option<bool>,
optimize: Option<bool>,
remove_app0: Option<bool>,
}
#[derive(Default, Debug)]
pub struct CvTiffEncFlags {
compression: Option<i32>,
}
impl CvTiffEncFlags {
pub fn new() -> Self {
Self::default().with_compression(1)
}
pub fn with_compression(mut self, compression: i32) -> Self {
self.compression = Some(compression);
self
}
fn to_cv_param_list(&self) -> Vector<i32> {
let iter = [(
ImwriteFlags::IMWRITE_TIFF_COMPRESSION as i32,
self.compression.map(|i| i as i32),
)]
.into_iter()
.filter_map(|(flag, opt)| opt.map(|o| [flag, o]))
.flatten();
Vector::from_iter(iter)
}
}
impl CvJpegEncFlags {
pub fn new() -> Self {
Self::default()
}
pub fn with_quality(mut self, quality: usize) -> Self {
self.quality = Some(quality);
self
}
pub fn remove_app0_marker(mut self, val: bool) -> Self {
self.remove_app0 = Some(val);
self
}
fn to_cv_param_list(&self) -> Vector<i32> {
let iter = [
(
ImwriteFlags::IMWRITE_JPEG_QUALITY as i32,
self.quality.map(|i| i as i32),
),
(
ImwriteFlags::IMWRITE_JPEG_PROGRESSIVE as i32,
self.progressive.map(|i| i as i32),
),
(
ImwriteFlags::IMWRITE_JPEG_OPTIMIZE as i32,
self.optimize.map(|i| i as i32),
),
]
.into_iter()
.filter_map(|(flag, opt)| opt.map(|o| [flag, o]))
.flatten();
Vector::from_iter(iter)
}
}
impl Encoder for CvEncoder {
type Input<'a>
= MatRef<'a>
where
Self: 'a;
fn encode(&self, input: Self::Input<'_>) -> Result<Vec<u8>, NdCvError> {
let mut buf = Vector::default();
let params = self.to_cv_param_list();
imencode(self.extension(), &input.as_ref(), &mut buf, &params).change_context(NdCvError)?;
match self.kind() {
EncKind::Jpeg => {
let bytes = Bytes::from(buf.to_vec());
let mut jpg = Jpeg::from_bytes(bytes).change_context(NdCvError)?;
jpg.remove_segments_by_marker(markers::APP0);
let bytes = jpg.encoder().bytes();
Ok(bytes.to_vec())
}
EncKind::Tiff => Ok(buf.to_vec()),
}
}
}
pub enum CvDecoder {
Jpeg(CvJpegDecFlags),
}
impl CvDecoder {
fn to_cv_decode_flag(&self) -> i32 {
match self {
Self::Jpeg(flags) => flags.to_cv_decode_flag(),
}
}
}
#[derive(Default)]
pub enum ColorMode {
#[default]
Color,
GrayScale,
}
impl ColorMode {
fn to_cv_decode_flag(&self) -> i32 {
match self {
Self::Color => ImreadModes::IMREAD_ANYCOLOR as i32,
Self::GrayScale => ImreadModes::IMREAD_GRAYSCALE as i32,
}
}
}
#[derive(Default)]
pub struct CvJpegDecFlags {
color_mode: ColorMode,
ignore_orientation: bool,
}
impl CvJpegDecFlags {
pub fn new() -> Self {
Self::default()
}
pub fn with_color_mode(mut self, color_mode: ColorMode) -> Self {
self.color_mode = color_mode;
self
}
pub fn with_ignore_orientation(mut self, ignore_orientation: bool) -> Self {
self.ignore_orientation = ignore_orientation;
self
}
fn to_cv_decode_flag(&self) -> i32 {
let flag = self.color_mode.to_cv_decode_flag();
if self.ignore_orientation {
flag | ImreadModes::IMREAD_IGNORE_ORIENTATION as i32
} else {
flag
}
}
}
impl Decoder for CvDecoder {
type Output = Mat;
fn decode(&self, input: impl AsRef<[u8]>) -> Result<Self::Output, NdCvError> {
let flag = self.to_cv_decode_flag();
let out = imdecode(&Vector::from_slice(input.as_ref()), flag).change_context(NdCvError)?;
Ok(out)
}
}

View File

@@ -0,0 +1,61 @@
#![deny(warnings)]
use super::codecs::CvDecoder;
use super::error::ErrorReason;
use crate::NdCvError;
use crate::{NdAsImage, conversions::NdCvConversion};
use error_stack::*;
use ndarray::Array;
use std::path::Path;
pub trait Decodable<D: Decoder>: Sized {
fn decode(buf: impl AsRef<[u8]>, decoder: &D) -> Result<Self, NdCvError> {
let output = decoder.decode(buf)?;
Self::transform(output)
}
fn read(&self, path: impl AsRef<Path>, decoder: &D) -> Result<Self, NdCvError> {
let buf = std::fs::read(path)
.map_err(|e| match e.kind() {
std::io::ErrorKind::NotFound => {
Report::new(e).attach_printable(ErrorReason::ImageWriteFileNotFound)
}
std::io::ErrorKind::PermissionDenied => {
Report::new(e).attach_printable(ErrorReason::ImageWritePermissionDenied)
}
std::io::ErrorKind::OutOfMemory => {
Report::new(e).attach_printable(ErrorReason::OutOfMemory)
}
std::io::ErrorKind::StorageFull => {
Report::new(e).attach_printable(ErrorReason::OutOfStorage)
}
_ => Report::new(e).attach_printable(ErrorReason::ImageWriteOtherError),
})
.change_context(NdCvError)?;
Self::decode(buf, decoder)
}
fn transform(input: D::Output) -> Result<Self, NdCvError>;
}
pub trait Decoder {
type Output: Sized;
fn decode(&self, buf: impl AsRef<[u8]>) -> Result<Self::Output, NdCvError>;
}
impl<T: bytemuck::Pod + Copy, D: ndarray::Dimension> Decodable<CvDecoder> for Array<T, D>
where
Self: NdAsImage<T, D>,
{
fn transform(input: <CvDecoder as Decoder>::Output) -> Result<Self, NdCvError> {
Self::from_mat(input)
}
}
#[test]
fn decode_image() {
use crate::codec::codecs::*;
let img = std::fs::read("/Users/fs0c131y/Projects/face-detector/assets/selfie.jpg").unwrap();
let decoder = CvDecoder::Jpeg(CvJpegDecFlags::new().with_ignore_orientation(true));
let _out = ndarray::Array3::<u8>::decode(img, &decoder).unwrap();
}

View File

@@ -0,0 +1,56 @@
use super::codecs::CvEncoder;
use super::error::ErrorReason;
use crate::conversions::NdAsImage;
use crate::NdCvError;
use error_stack::*;
use ndarray::ArrayBase;
use std::path::Path;
pub trait Encodable<E: Encoder> {
fn encode(&self, encoder: &E) -> Result<Vec<u8>, NdCvError> {
let input = self.transform()?;
encoder.encode(input)
}
fn write(&self, path: impl AsRef<Path>, encoder: &E) -> Result<(), NdCvError> {
let buf = self.encode(encoder)?;
std::fs::write(path, buf)
.map_err(|e| match e.kind() {
std::io::ErrorKind::NotFound => {
Report::new(e).attach_printable(ErrorReason::ImageWriteFileNotFound)
}
std::io::ErrorKind::PermissionDenied => {
Report::new(e).attach_printable(ErrorReason::ImageWritePermissionDenied)
}
std::io::ErrorKind::OutOfMemory => {
Report::new(e).attach_printable(ErrorReason::OutOfMemory)
}
std::io::ErrorKind::StorageFull => {
Report::new(e).attach_printable(ErrorReason::OutOfStorage)
}
_ => Report::new(e).attach_printable(ErrorReason::ImageWriteOtherError),
})
.change_context(NdCvError)
}
fn transform(&self) -> Result<<E as Encoder>::Input<'_>, NdCvError>;
}
pub trait Encoder {
type Input<'a>
where
Self: 'a;
fn encode(&self, input: Self::Input<'_>) -> Result<Vec<u8>, NdCvError>;
}
impl<T: bytemuck::Pod + Copy, S: ndarray::Data<Elem = T>, D: ndarray::Dimension>
Encodable<CvEncoder> for ArrayBase<S, D>
where
Self: NdAsImage<T, D>,
{
fn transform(&self) -> Result<<CvEncoder as Encoder>::Input<'_>, NdCvError> {
self.as_image_mat()
}
}

View File

@@ -0,0 +1,19 @@
#[derive(Debug)]
pub enum ErrorReason {
ImageReadFileNotFound,
ImageReadPermissionDenied,
ImageReadOtherError,
ImageWriteFileNotFound,
ImageWritePermissionDenied,
ImageWriteOtherError,
OutOfMemory,
OutOfStorage,
}
impl std::fmt::Display for ErrorReason {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}

View File

@@ -0,0 +1,88 @@
//! Colorspace conversion functions
//! ## Example
//! ```rust
//! let arr = Array3::<u8>::ones((100, 100, 3));
//! let out: Array3<u8> = arr.cvt::<Rgba<u8>, Rgb<u8>>()
//! ```
use crate::prelude_::*;
use ndarray::*;
pub trait ColorSpace {
type Elem: seal::Sealed;
type Dim: ndarray::Dimension;
const CHANNELS: usize;
}
mod seal {
pub trait Sealed: bytemuck::Pod {}
// impl<T> Sealed for T {}
impl Sealed for u8 {} // 0 to 255
impl Sealed for u16 {} // 0 to 65535
impl Sealed for f32 {} // 0 to 1
}
macro_rules! define_color_space {
($name:ident, $channels:expr, $depth:ty) => {
pub struct $name<T> {
__phantom: core::marker::PhantomData<T>,
}
impl<T: seal::Sealed> ColorSpace for $name<T> {
type Elem = T;
type Dim = $depth;
const CHANNELS: usize = $channels;
}
};
}
define_color_space!(Rgb, 3, Ix3);
define_color_space!(Bgr, 3, Ix3);
define_color_space!(Rgba, 4, Ix3);
pub trait NdArray<T, D: ndarray::Dimension> {}
impl<T, D: ndarray::Dimension, S: ndarray::Data<Elem = T>> NdArray<S, D> for ArrayBase<S, D> {}
pub trait ConvertColor<T, U>
where
T: ColorSpace,
U: ColorSpace,
Self: NdArray<T::Elem, T::Dim>,
{
type Output: NdArray<U::Elem, U::Dim>;
fn cvt(&self) -> Self::Output;
}
// impl<T: seal::Sealed, S: ndarray::Data<Elem = T>> ConvertColor<Rgb<T>, Bgr<T>> for ArrayBase<S, Ix3>
// where
// Self: NdArray<T, Ix3>,
// {
// type Output = ArrayView3<'a, T>;
// fn cvt(&self) -> CowArray<T, Ix3> {
// self.view().permuted_axes([2, 1, 0]).into()
// }
// }
//
// impl<T: seal::Sealed, S: ndarray::Data<Elem = T>> ConvertColor<Bgr<T>, Rgb<T>> for ArrayBase<S, Ix3>
// where
// Self: NdArray<T, Ix3>,
// {
// type Output = ArrayView3<'a, T>;
// fn cvt(&self) -> CowArray<T, Ix3> {
// self.view().permuted_axes([2, 1, 0]).into()
// }
// }
// impl<T: seal::Sealed + num::One + num::Zero, S: ndarray::Data<Elem = T>>
// ConvertColor<Rgb<T>, Rgba<T>> for ArrayBase<S, Ix3>
// {
// fn cvt(&self) -> CowArray<T, Ix3> {
// let mut out = Array3::<T>::zeros((self.height(), self.width(), 4));
// // Zip::from(&mut out).and(self).for_each(|out, &in_| {
// // out[0] = in_[0];
// // out[1] = in_[1];
// // out[2] = in_[2];
// // out[3] = T::one();
// // });
// out.into()
// }
// }

View File

@@ -0,0 +1,113 @@
use crate::{NdAsImage, NdAsImageMut, conversions::MatAsNd, prelude_::*};
pub(crate) mod seal {
pub trait ConnectedComponentOutput: Sized + Copy + bytemuck::Pod + num::Zero {
fn as_cv_type() -> i32 {
crate::type_depth::<Self>()
}
}
impl ConnectedComponentOutput for i32 {}
impl ConnectedComponentOutput for u16 {}
}
pub trait NdCvConnectedComponents<T> {
fn connected_components<O: seal::ConnectedComponentOutput>(
&self,
connectivity: Connectivity,
) -> Result<ndarray::Array2<O>, NdCvError>;
fn connected_components_with_stats<O: seal::ConnectedComponentOutput>(
&self,
connectivity: Connectivity,
) -> Result<ConnectedComponentStats<O>, NdCvError>;
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Connectivity {
Four = 4,
#[default]
Eight = 8,
}
#[derive(Debug, Clone)]
pub struct ConnectedComponentStats<O: seal::ConnectedComponentOutput> {
pub num_labels: i32,
pub labels: ndarray::Array2<O>,
pub stats: ndarray::Array2<i32>,
pub centroids: ndarray::Array2<f64>,
}
// use crate::conversions::NdCvConversionRef;
impl<T: bytemuck::Pod, S: ndarray::Data<Elem = T>> NdCvConnectedComponents<T>
for ndarray::ArrayBase<S, ndarray::Ix2>
where
ndarray::Array2<T>: NdAsImage<T, ndarray::Ix2>,
{
fn connected_components<O: seal::ConnectedComponentOutput>(
&self,
connectivity: Connectivity,
) -> Result<ndarray::Array2<O>, NdCvError> {
let mat = self.as_image_mat()?;
let mut labels = ndarray::Array2::<O>::zeros(self.dim());
let mut cv_labels = labels.as_image_mat_mut()?;
opencv::imgproc::connected_components(
mat.as_ref(),
cv_labels.as_mut(),
connectivity as i32,
O::as_cv_type(),
)
.change_context(NdCvError)?;
Ok(labels)
}
fn connected_components_with_stats<O: seal::ConnectedComponentOutput>(
&self,
connectivity: Connectivity,
) -> Result<ConnectedComponentStats<O>, NdCvError> {
let mut labels = ndarray::Array2::<O>::zeros(self.dim());
let mut stats = opencv::core::Mat::default();
let mut centroids = opencv::core::Mat::default();
let num_labels = opencv::imgproc::connected_components_with_stats(
self.as_image_mat()?.as_ref(),
labels.as_image_mat_mut()?.as_mut(),
&mut stats,
&mut centroids,
connectivity as i32,
O::as_cv_type(),
)
.change_context(NdCvError)?;
let stats = stats.as_ndarray()?.to_owned();
let centroids = centroids.as_ndarray()?.to_owned();
Ok(ConnectedComponentStats {
labels,
stats,
centroids,
num_labels,
})
}
}
// #[test]
// fn test_connected_components() {
// use opencv::core::MatTrait as _;
// let mat = opencv::core::Mat::new_nd_with_default(&[10, 10], opencv::core::CV_8UC1, 0.into())
// .expect("failed");
// let roi1 = opencv::core::Rect::new(2, 2, 2, 2);
// let roi2 = opencv::core::Rect::new(6, 6, 3, 3);
// let mut mat1 = opencv::core::Mat::roi(&mat, roi1).expect("failed");
// mat1.set_scalar(1.into()).expect("failed");
// let mut mat2 = opencv::core::Mat::roi(&mat, roi2).expect("failed");
// mat2.set_scalar(1.into()).expect("failed");
// let array2: ndarray::ArrayView2<u8> = mat.as_ndarray().expect("failed");
// let output = array2
// .connected_components::<u16>(Connectivity::Four)
// .expect("failed");
// let expected = {
// let mut expected = ndarray::Array2::zeros((10, 10));
// expected.slice_mut(ndarray::s![2..4, 2..4]).fill(1);
// expected.slice_mut(ndarray::s![6..9, 6..9]).fill(2);
// expected
// };
// assert_eq!(output, expected);
// }

270
ndcv-bridge/src/contours.rs Normal file
View File

@@ -0,0 +1,270 @@
//! <https://docs.rs/opencv/latest/opencv/imgproc/fn.find_contours.html>
#![deny(warnings)]
use crate::conversions::*;
use crate::prelude_::*;
use nalgebra::Point2;
use ndarray::*;
#[repr(C)]
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)]
pub enum ContourRetrievalMode {
#[default]
External = 0, // RETR_EXTERNAL
List = 1, // RETR_LIST
CComp = 2, // RETR_CCOMP
Tree = 3, // RETR_TREE
FloodFill = 4, // RETR_FLOODFILL
}
#[repr(C)]
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)]
pub enum ContourApproximationMethod {
#[default]
None = 1, // CHAIN_APPROX_NONE
Simple = 2, // CHAIN_APPROX_SIMPLE
Tc89L1 = 3, // CHAIN_APPROX_TC89_L1
Tc89Kcos = 4, // CHAIN_APPROX_TC89_KCOS
}
#[derive(Debug, Clone)]
pub struct ContourHierarchy {
pub next: i32,
pub previous: i32,
pub first_child: i32,
pub parent: i32,
}
#[derive(Debug, Clone)]
pub struct ContourResult {
pub contours: Vec<Vec<Point2<i32>>>,
pub hierarchy: Vec<ContourHierarchy>,
}
mod seal {
pub trait Sealed {}
impl Sealed for u8 {}
}
pub trait NdCvFindContours<T: bytemuck::Pod + seal::Sealed>:
crate::image::NdImage + crate::conversions::NdAsImage<T, ndarray::Ix2>
{
fn find_contours(
&self,
mode: ContourRetrievalMode,
method: ContourApproximationMethod,
) -> Result<Vec<Vec<Point2<i32>>>, NdCvError>;
fn find_contours_with_hierarchy(
&self,
mode: ContourRetrievalMode,
method: ContourApproximationMethod,
) -> Result<ContourResult, NdCvError>;
fn find_contours_def(&self) -> Result<Vec<Vec<Point2<i32>>>, NdCvError> {
self.find_contours(
ContourRetrievalMode::External,
ContourApproximationMethod::Simple,
)
}
fn find_contours_with_hierarchy_def(&self) -> Result<ContourResult, NdCvError> {
self.find_contours_with_hierarchy(
ContourRetrievalMode::External,
ContourApproximationMethod::Simple,
)
}
}
pub trait NdCvContourArea<T: bytemuck::Pod> {
fn contours_area(&self, oriented: bool) -> Result<f64, NdCvError>;
fn contours_area_def(&self) -> Result<f64, NdCvError> {
self.contours_area(false)
}
}
impl<T: ndarray::RawData + ndarray::Data<Elem = u8>> NdCvFindContours<u8> for ArrayBase<T, Ix2> {
fn find_contours(
&self,
mode: ContourRetrievalMode,
method: ContourApproximationMethod,
) -> Result<Vec<Vec<Point2<i32>>>, NdCvError> {
let cv_self = self.as_image_mat()?;
let mut contours = opencv::core::Vector::<opencv::core::Vector<opencv::core::Point>>::new();
opencv::imgproc::find_contours(
&*cv_self,
&mut contours,
mode as i32,
method as i32,
opencv::core::Point::new(0, 0),
)
.change_context(NdCvError)
.attach_printable("Failed to find contours")?;
let mut result: Vec<Vec<Point2<i32>>> = Vec::new();
for i in 0..contours.len() {
let contour = contours.get(i).change_context(NdCvError)?;
let points: Vec<Point2<i32>> =
contour.iter().map(|pt| Point2::new(pt.x, pt.y)).collect();
result.push(points);
}
Ok(result)
}
fn find_contours_with_hierarchy(
&self,
mode: ContourRetrievalMode,
method: ContourApproximationMethod,
) -> Result<ContourResult, NdCvError> {
let cv_self = self.as_image_mat()?;
let mut contours = opencv::core::Vector::<opencv::core::Vector<opencv::core::Point>>::new();
let mut hierarchy = opencv::core::Vector::<opencv::core::Vec4i>::new();
opencv::imgproc::find_contours_with_hierarchy(
&*cv_self,
&mut contours,
&mut hierarchy,
mode as i32,
method as i32,
opencv::core::Point::new(0, 0),
)
.change_context(NdCvError)
.attach_printable("Failed to find contours with hierarchy")?;
let mut contour_list: Vec<Vec<Point2<i32>>> = Vec::new();
for i in 0..contours.len() {
let contour = contours.get(i).change_context(NdCvError)?;
let points: Vec<Point2<i32>> =
contour.iter().map(|pt| Point2::new(pt.x, pt.y)).collect();
contour_list.push(points);
}
let mut hierarchy_list = Vec::new();
for i in 0..hierarchy.len() {
let h = hierarchy.get(i).change_context(NdCvError)?;
hierarchy_list.push(ContourHierarchy {
next: h[0],
previous: h[1],
first_child: h[2],
parent: h[3],
});
}
Ok(ContourResult {
contours: contour_list,
hierarchy: hierarchy_list,
})
}
}
impl<T> NdCvContourArea<T> for Vec<Point2<T>>
where
T: bytemuck::Pod + num::traits::AsPrimitive<i32> + std::cmp::PartialEq + std::fmt::Debug + Copy,
{
fn contours_area(&self, oriented: bool) -> Result<f64, NdCvError> {
if self.is_empty() {
return Ok(0.0);
}
let mut cv_contour: opencv::core::Vector<opencv::core::Point> = opencv::core::Vector::new();
self.iter().for_each(|point| {
cv_contour.push(opencv::core::Point::new(
point.coords[0].as_(),
point.coords[1].as_(),
));
});
opencv::imgproc::contour_area(&cv_contour, oriented)
.change_context(NdCvError)
.attach_printable("Failed to calculate contour area")
}
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::Array2;
fn simple_binary_rect_image() -> Array2<u8> {
let mut img = Array2::<u8>::zeros((10, 10));
for i in 2..8 {
for j in 3..7 {
img[(i, j)] = 255;
}
}
img
}
#[test]
fn test_find_contours_external_simple() {
let img = simple_binary_rect_image();
let contours = img
.find_contours(
ContourRetrievalMode::External,
ContourApproximationMethod::Simple,
)
.expect("Failed to find contours");
assert_eq!(contours.len(), 1);
assert!(contours[0].len() >= 4);
}
#[test]
fn test_find_contours_with_hierarchy() {
let img = simple_binary_rect_image();
let res = img
.find_contours_with_hierarchy(
ContourRetrievalMode::External,
ContourApproximationMethod::Simple,
)
.expect("Failed to find contours with hierarchy");
assert_eq!(res.contours.len(), 1);
assert_eq!(res.hierarchy.len(), 1);
let h = &res.hierarchy[0];
assert_eq!(h.parent, -1);
assert_eq!(h.first_child, -1);
}
#[test]
fn test_default_methods() {
let img = simple_binary_rect_image();
let contours = img.find_contours_def().unwrap();
let res = img.find_contours_with_hierarchy_def().unwrap();
assert_eq!(contours.len(), 1);
assert_eq!(res.contours.len(), 1);
}
#[test]
fn test_contour_area_calculation() {
let img = simple_binary_rect_image();
let contours = img.find_contours_def().unwrap();
let expected_area = 15.;
let area = contours[0].contours_area_def().unwrap();
assert!(
(area - expected_area).abs() < 1.0,
"Area mismatch: got {area}, expected {expected_area}",
);
}
#[test]
fn test_empty_input_returns_no_contours() {
let img = Array2::<u8>::zeros((10, 10));
let contours = img.find_contours_def().unwrap();
assert!(contours.is_empty());
let res = img.find_contours_with_hierarchy_def().unwrap();
assert!(res.contours.is_empty());
assert!(res.hierarchy.is_empty());
}
#[test]
fn test_contour_area_empty_contour() {
let contour: Vec<Point2<i32>> = vec![];
let area = contour.contours_area_def().unwrap();
assert_eq!(area, 0.0);
}
}

View File

@@ -0,0 +1,337 @@
//! Mat <--> ndarray conversion traits
//!
//! Conversion Table
//!
//! | ndarray | Mat |
//! |--------- |----- |
//! | Array<T, Ix1> | Mat(ndims = 1, channels = 1) |
//! | Array<T, Ix2> | Mat(ndims = 2, channels = 1) |
//! | Array<T, Ix2> | Mat(ndims = 1, channels = X) |
//! | Array<T, Ix3> | Mat(ndims = 3, channels = 1) |
//! | Array<T, Ix3> | Mat(ndims = 2, channels = X) |
//! | Array<T, Ix4> | Mat(ndims = 4, channels = 1) |
//! | Array<T, Ix4> | Mat(ndims = 3, channels = X) |
//! | Array<T, Ix5> | Mat(ndims = 5, channels = 1) |
//! | Array<T, Ix5> | Mat(ndims = 4, channels = X) |
//! | Array<T, Ix6> | Mat(ndims = 6, channels = 1) |
//! | Array<T, Ix6> | Mat(ndims = 5, channels = X) |
//!
//! // X is the last dimension
use crate::NdCvError;
use crate::type_depth;
use error_stack::*;
use ndarray::{Ix2, Ix3};
use opencv::core::MatTraitConst;
mod impls;
pub(crate) mod matref;
use matref::{MatRef, MatRefMut};
pub(crate) mod seal {
pub trait SealedInternal {}
impl<T, S: ndarray::Data<Elem = T>, D> SealedInternal for ndarray::ArrayBase<S, D> {}
// impl<T, S: ndarray::DataMut<Elem = T>, D> SealedInternal for ndarray::ArrayBase<S, D> {}
}
pub trait NdCvConversion<T: bytemuck::Pod + Copy, D: ndarray::Dimension>:
seal::SealedInternal + Sized
{
fn to_mat(&self) -> Result<opencv::core::Mat, NdCvError>;
fn from_mat(
mat: opencv::core::Mat,
) -> Result<ndarray::ArrayBase<ndarray::OwnedRepr<T>, D>, NdCvError>;
}
impl<T: bytemuck::Pod + Copy, S: ndarray::Data<Elem = T>, D: ndarray::Dimension>
NdCvConversion<T, D> for ndarray::ArrayBase<S, D>
where
Self: NdAsImage<T, D>,
{
fn to_mat(&self) -> Result<opencv::core::Mat, NdCvError> {
Ok(self.as_image_mat()?.mat.clone())
}
fn from_mat(
mat: opencv::core::Mat,
) -> Result<ndarray::ArrayBase<ndarray::OwnedRepr<T>, D>, NdCvError> {
let ndarray = unsafe { impls::mat_to_ndarray::<T, D>(&mat) }.change_context(NdCvError)?;
Ok(ndarray.to_owned())
}
}
pub trait MatAsNd {
fn as_ndarray<T: bytemuck::Pod, D: ndarray::Dimension>(
&self,
) -> Result<ndarray::ArrayView<T, D>, NdCvError>;
}
impl MatAsNd for opencv::core::Mat {
fn as_ndarray<T: bytemuck::Pod, D: ndarray::Dimension>(
&self,
) -> Result<ndarray::ArrayView<T, D>, NdCvError> {
unsafe { impls::mat_to_ndarray::<T, D>(self) }.change_context(NdCvError)
}
}
pub trait NdAsMat<T: bytemuck::Pod + Copy, D: ndarray::Dimension> {
fn as_single_channel_mat(&self) -> Result<MatRef, NdCvError>;
fn as_multi_channel_mat(&self) -> Result<MatRef, NdCvError>;
}
pub trait NdAsMatMut<T: bytemuck::Pod + Copy, D: ndarray::Dimension>: NdAsMat<T, D> {
fn as_single_channel_mat_mut(&mut self) -> Result<MatRefMut, NdCvError>;
fn as_multi_channel_mat_mut(&mut self) -> Result<MatRefMut, NdCvError>;
}
impl<T: bytemuck::Pod, S: ndarray::Data<Elem = T>, D: ndarray::Dimension> NdAsMat<T, D>
for ndarray::ArrayBase<S, D>
{
fn as_single_channel_mat(&self) -> Result<MatRef, NdCvError> {
let mat = unsafe { impls::ndarray_to_mat_regular(self) }.change_context(NdCvError)?;
Ok(MatRef::new(mat))
}
fn as_multi_channel_mat(&self) -> Result<MatRef, NdCvError> {
let mat = unsafe { impls::ndarray_to_mat_consolidated(self) }.change_context(NdCvError)?;
Ok(MatRef::new(mat))
}
}
impl<T: bytemuck::Pod, S: ndarray::DataMut<Elem = T>, D: ndarray::Dimension> NdAsMatMut<T, D>
for ndarray::ArrayBase<S, D>
{
fn as_single_channel_mat_mut(&mut self) -> Result<MatRefMut, NdCvError> {
let mat = unsafe { impls::ndarray_to_mat_regular(self) }.change_context(NdCvError)?;
Ok(MatRefMut::new(mat))
}
fn as_multi_channel_mat_mut(&mut self) -> Result<MatRefMut, NdCvError> {
let mat = unsafe { impls::ndarray_to_mat_consolidated(self) }.change_context(NdCvError)?;
Ok(MatRefMut::new(mat))
}
}
pub trait NdAsImage<T: bytemuck::Pod, D: ndarray::Dimension> {
fn as_image_mat(&self) -> Result<MatRef, NdCvError>;
}
pub trait NdAsImageMut<T: bytemuck::Pod, D: ndarray::Dimension> {
fn as_image_mat_mut(&mut self) -> Result<MatRefMut, NdCvError>;
}
impl<T, S> NdAsImage<T, Ix2> for ndarray::ArrayBase<S, Ix2>
where
T: bytemuck::Pod + Copy,
S: ndarray::Data<Elem = T>,
{
fn as_image_mat(&self) -> Result<MatRef, NdCvError> {
self.as_single_channel_mat()
}
}
impl<T, S> NdAsImageMut<T, Ix2> for ndarray::ArrayBase<S, Ix2>
where
T: bytemuck::Pod + Copy,
S: ndarray::DataMut<Elem = T>,
{
fn as_image_mat_mut(&mut self) -> Result<MatRefMut, NdCvError> {
self.as_single_channel_mat_mut()
}
}
impl<T, S> NdAsImage<T, Ix3> for ndarray::ArrayBase<S, Ix3>
where
T: bytemuck::Pod + Copy,
S: ndarray::Data<Elem = T>,
{
fn as_image_mat(&self) -> Result<MatRef, NdCvError> {
self.as_multi_channel_mat()
}
}
impl<T, S> NdAsImageMut<T, Ix3> for ndarray::ArrayBase<S, Ix3>
where
T: bytemuck::Pod + Copy,
S: ndarray::DataMut<Elem = T>,
{
fn as_image_mat_mut(&mut self) -> Result<MatRefMut, NdCvError> {
self.as_multi_channel_mat_mut()
}
}
// #[test]
// fn test_1d_mat_to_ndarray() {
// let mat = opencv::core::Mat::new_nd_with_default(
// &[10],
// opencv::core::CV_MAKE_TYPE(opencv::core::CV_8U, 1),
// 200.into(),
// )
// .expect("failed");
// let array: ndarray::ArrayView1<u8> = mat.as_ndarray().expect("failed");
// array.into_iter().for_each(|&x| assert_eq!(x, 200));
// }
// #[test]
// fn test_2d_mat_to_ndarray() {
// let mat = opencv::core::Mat::new_nd_with_default(
// &[10],
// opencv::core::CV_16SC3,
// (200, 200, 200).into(),
// )
// .expect("failed");
// let array2: ndarray::ArrayView2<i16> = mat.as_ndarray().expect("failed");
// assert_eq!(array2.shape(), [10, 3]);
// array2.into_iter().for_each(|&x| {
// assert_eq!(x, 200);
// });
// }
// #[test]
// fn test_3d_mat_to_ndarray() {
// let mat = opencv::core::Mat::new_nd_with_default(
// &[20, 30],
// opencv::core::CV_32FC3,
// (200, 200, 200).into(),
// )
// .expect("failed");
// let array2: ndarray::ArrayView3<f32> = mat.as_ndarray().expect("failed");
// array2.into_iter().for_each(|&x| {
// assert_eq!(x, 200f32);
// });
// }
// #[test]
// fn test_mat_to_dyn_ndarray() {
// let mat = opencv::core::Mat::new_nd_with_default(&[10], opencv::core::CV_8UC1, 200.into())
// .expect("failed");
// let array2: ndarray::ArrayViewD<u8> = mat.as_ndarray().expect("failed");
// array2.into_iter().for_each(|&x| assert_eq!(x, 200));
// }
// #[test]
// fn test_3d_mat_to_ndarray_4k() {
// let mat = opencv::core::Mat::new_nd_with_default(
// &[4096, 4096],
// opencv::core::CV_8UC3,
// (255, 0, 255).into(),
// )
// .expect("failed");
// let array2: ndarray::ArrayView3<u8> = (mat).as_ndarray().expect("failed");
// array2.exact_chunks((1, 1, 3)).into_iter().for_each(|x| {
// assert_eq!(x[(0, 0, 0)], 255);
// assert_eq!(x[(0, 0, 1)], 0);
// assert_eq!(x[(0, 0, 2)], 255);
// });
// }
// // #[test]
// // fn test_3d_mat_to_ndarray_8k() {
// // let mat = opencv::core::Mat::new_nd_with_default(
// // &[8192, 8192],
// // opencv::core::CV_8UC3,
// // (255, 0, 255).into(),
// // )
// // .expect("failed");
// // let array2 = ndarray::Array3::<u8>::from_mat(mat).expect("failed");
// // array2.exact_chunks((1, 1, 3)).into_iter().for_each(|x| {
// // assert_eq!(x[(0, 0, 0)], 255);
// // assert_eq!(x[(0, 0, 1)], 0);
// // assert_eq!(x[(0, 0, 2)], 255);
// // });
// // }
// #[test]
// pub fn test_mat_to_nd_default_strides() {
// let mat = opencv::core::Mat::new_rows_cols_with_default(
// 10,
// 10,
// opencv::core::CV_8UC3,
// opencv::core::VecN([10f64, 0.0, 0.0, 0.0]),
// )
// .expect("failed");
// let array = unsafe { impls::mat_to_ndarray::<u8, Ix3>(&mat) }.expect("failed");
// assert_eq!(array.shape(), [10, 10, 3]);
// assert_eq!(array.strides(), [30, 3, 1]);
// assert_eq!(array[(0, 0, 0)], 10);
// }
// #[test]
// pub fn test_mat_to_nd_custom_strides() {
// let mat = opencv::core::Mat::new_rows_cols_with_default(
// 10,
// 10,
// opencv::core::CV_8UC3,
// opencv::core::VecN([10f64, 0.0, 0.0, 0.0]),
// )
// .unwrap();
// let mat_roi = opencv::core::Mat::roi(&mat, opencv::core::Rect::new(3, 2, 3, 5))
// .expect("failed to get roi");
// let array = unsafe { impls::mat_to_ndarray::<u8, Ix3>(&mat_roi) }.expect("failed");
// assert_eq!(array.shape(), [5, 3, 3]);
// assert_eq!(array.strides(), [30, 3, 1]);
// assert_eq!(array[(0, 0, 0)], 10);
// }
// #[test]
// pub fn test_non_continuous_3d() {
// let array = ndarray::Array3::<f32>::from_shape_fn((10, 10, 4), |(i, j, k)| {
// ((i + 1) * (j + 1) * (k + 1)) as f32
// });
// let slice = array.slice(ndarray::s![3..7, 3..7, 0..4]);
// let mat = unsafe { impls::ndarray_to_mat_consolidated(&slice) }.unwrap();
// let arr = unsafe { impls::mat_to_ndarray::<f32, Ix3>(&mat).unwrap() };
// assert!(slice == arr);
// }
// #[test]
// pub fn test_5d_array() {
// let array = ndarray::Array5::<f32>::ones((1, 2, 3, 4, 5));
// let mat = unsafe { impls::ndarray_to_mat_consolidated(&array) }.unwrap();
// let arr = unsafe { impls::mat_to_ndarray::<f32, ndarray::Ix5>(&mat).unwrap() };
// assert_eq!(array, arr);
// }
// #[test]
// pub fn test_3d_array() {
// let array = ndarray::Array3::<f32>::ones((23, 31, 33));
// let mat = unsafe { impls::ndarray_to_mat_consolidated(&array) }.unwrap();
// let arr = unsafe { impls::mat_to_ndarray::<f32, ndarray::Ix3>(&mat).unwrap() };
// assert_eq!(array, arr);
// }
// #[test]
// pub fn test_2d_array() {
// let array = ndarray::Array2::<f32>::ones((23, 31));
// let mat = unsafe { impls::ndarray_to_mat_consolidated(&array) }.unwrap();
// let arr = unsafe { impls::mat_to_ndarray::<f32, ndarray::Ix2>(&mat).unwrap() };
// assert_eq!(array, arr);
// }
// #[test]
// #[should_panic]
// pub fn test_1d_array_consolidated() {
// let array = ndarray::Array1::<f32>::ones(23);
// let mat = unsafe { impls::ndarray_to_mat_consolidated(&array) }.unwrap();
// let arr = unsafe { impls::mat_to_ndarray::<f32, ndarray::Ix1>(&mat).unwrap() };
// assert_eq!(array, arr);
// }
// #[test]
// pub fn test_1d_array_regular() {
// let array = ndarray::Array1::<f32>::ones(23);
// let mat = unsafe { impls::ndarray_to_mat_regular(&array) }.unwrap();
// let arr = unsafe { impls::mat_to_ndarray::<f32, ndarray::Ix1>(&mat).unwrap() };
// assert_eq!(array, arr);
// }
// #[test]
// pub fn test_2d_array_regular() {
// let array = ndarray::Array2::<f32>::ones((23, 31));
// let mat = unsafe { impls::ndarray_to_mat_regular(&array) }.unwrap();
// let arr = unsafe { impls::mat_to_ndarray::<f32, ndarray::Ix2>(&mat).unwrap() };
// assert_eq!(array, arr);
// }
// #[test]
// pub fn test_ndcv_1024_1024_to_mat() {
// let array = ndarray::Array2::<f32>::ones((1024, 1024));
// let _mat = array.to_mat().unwrap();
// }

View File

@@ -0,0 +1,168 @@
use super::*;
use core::ffi::*;
use opencv::core::prelude::*;
pub(crate) unsafe fn ndarray_to_mat_regular<
T,
S: ndarray::Data<Elem = T>,
D: ndarray::Dimension,
>(
input: &ndarray::ArrayBase<S, D>,
) -> Result<opencv::core::Mat, NdCvError> {
let shape = input.shape();
let strides = input.strides();
// let channels = shape.last().copied().unwrap_or(1);
// if channels > opencv::core::CV_CN_MAX as usize {
// Err(Report::new(NdCvError).attach_printable(format!(
// "Number of channels({channels}) exceeds CV_CN_MAX({}) use the regular version of the function", opencv::core::CV_CN_MAX
// )))?;
// }
// let size_len = shape.len();
let size = shape.iter().copied().map(|f| f as i32).collect::<Vec<_>>();
// Step len for ndarray is always 1 less than ndims
let step_len = strides.len() - 1;
let step = strides
.iter()
.take(step_len)
.copied()
.map(|f| f as usize * core::mem::size_of::<T>())
.collect::<Vec<_>>();
let data_ptr = input.as_ptr() as *const c_void;
let typ = opencv::core::CV_MAKETYPE(type_depth::<T>(), 1);
let mat = opencv::core::Mat::new_nd_with_data_unsafe(
size.as_slice(),
typ,
data_ptr.cast_mut(),
Some(step.as_slice()),
)
.change_context(NdCvError)?;
Ok(mat)
}
pub(crate) unsafe fn ndarray_to_mat_consolidated<
T,
S: ndarray::Data<Elem = T>,
D: ndarray::Dimension,
>(
input: &ndarray::ArrayBase<S, D>,
) -> Result<opencv::core::Mat, NdCvError> {
let shape = input.shape();
let strides = input.strides();
let channels = shape.last().copied().unwrap_or(1);
if channels > opencv::core::CV_CN_MAX as usize {
Err(Report::new(NdCvError).attach_printable(format!(
"Number of channels({channels}) exceeds CV_CN_MAX({}) use the regular version of the function", opencv::core::CV_CN_MAX
)))?;
}
if shape.len() > 2 {
// Basically the second last stride is used to jump from one column to next
// But opencv only keeps ndims - 1 strides so we can't have the column stride as that
// will be lost
if shape.last() != strides.get(strides.len() - 2).map(|x| *x as usize).as_ref() {
Err(Report::new(NdCvError).attach_printable(
"You cannot slice into the last axis in ndarray when converting to mat",
))?;
}
} else if shape.len() == 1 {
return Err(Report::new(NdCvError).attach_printable(
"You cannot convert a 1D array to a Mat while using the consolidated version",
));
}
// Since this is the consolidated version we should always only have ndims - 1 sizes and
// ndims - 2 strides
let size_len = shape.len() - 1; // Since we move last axis into the channel
let size = shape
.iter()
.take(size_len)
.map(|f| *f as i32)
.collect::<Vec<_>>();
let step_len = strides.len() - 1;
let step = strides
.iter()
.take(step_len)
.map(|f| *f as usize * core::mem::size_of::<T>())
.collect::<Vec<_>>();
let data_ptr = input.as_ptr() as *const c_void;
let typ = opencv::core::CV_MAKETYPE(type_depth::<T>(), channels as i32);
let mat = opencv::core::Mat::new_nd_with_data_unsafe(
size.as_slice(),
typ,
data_ptr.cast_mut(),
Some(step.as_slice()),
)
.change_context(NdCvError)?;
Ok(mat)
}
pub(crate) unsafe fn mat_to_ndarray<T: bytemuck::Pod, D: ndarray::Dimension>(
mat: &opencv::core::Mat,
) -> Result<ndarray::ArrayView<'_, T, D>, NdCvError> {
let depth = mat.depth();
if type_depth::<T>() != depth {
return Err(Report::new(NdCvError).attach_printable(format!(
"Expected type Mat<{}> ({}), got Mat<{}> ({})",
std::any::type_name::<T>(),
type_depth::<T>(),
crate::depth_type(depth),
depth,
)));
}
// Since a dims always returns >= 2 we can't use this to check if it's a 1D array
// So we compare the first axis to the total to see if its a 1D array
let is_1d = mat.total() as i32 == mat.rows();
let dims = is_1d.then_some(1).unwrap_or(mat.dims());
let channels = mat.channels();
let ndarray_size = (channels != 1).then_some(dims + 1).unwrap_or(dims) as usize;
if let Some(ndim) = D::NDIM {
// When channels is not 1,
// the last dimension is the channels
// Array1 -> Mat(ndims = 1, channels = 1)
// Array2 -> Mat(ndims = 1, channels = X)
// Array2 -> Mat(ndims = 2, channels = 1)
// Array3 -> Mat(ndims = 2, channels = X)
// Array3 -> Mat(ndims = 3, channels = 1)
// ...
if ndim != dims as usize && channels == 1 {
return Err(Report::new(NdCvError)
.attach_printable(format!("Expected {}D array, got {}D", ndim, ndarray_size)));
}
}
let mat_size = mat.mat_size();
let sizes = (0..dims)
.map(|i| mat_size.get(i).change_context(NdCvError))
.chain([Ok(channels)])
.map(|x| x.map(|x| x as usize))
.take(ndarray_size)
.collect::<Result<Vec<_>, NdCvError>>()
.change_context(NdCvError)?;
let strides = (0..(dims - 1))
.map(|i| mat.step1(i).change_context(NdCvError))
.chain([
Ok(channels as usize),
Ok((channels == 1).then_some(0).unwrap_or(1)),
])
.take(ndarray_size)
.collect::<Result<Vec<_>, NdCvError>>()
.change_context(NdCvError)?;
use ndarray::ShapeBuilder;
let shape = sizes.strides(strides);
let raw_array = ndarray::RawArrayView::from_shape_ptr(shape, mat.data() as *const T)
.into_dimensionality()
.change_context(NdCvError)?;
Ok(unsafe { raw_array.deref_into_view() })
}

View File

@@ -0,0 +1,73 @@
#[derive(Debug, Clone)]
pub struct MatRef<'a> {
pub(crate) mat: opencv::core::Mat,
pub(crate) _marker: core::marker::PhantomData<&'a ()>,
}
impl MatRef<'_> {
pub fn clone_pointee(&self) -> opencv::core::Mat {
self.mat.clone()
}
}
impl MatRef<'_> {
pub fn new<'a>(mat: opencv::core::Mat) -> MatRef<'a> {
MatRef {
mat,
_marker: core::marker::PhantomData,
}
}
}
impl AsRef<opencv::core::Mat> for MatRef<'_> {
fn as_ref(&self) -> &opencv::core::Mat {
&self.mat
}
}
impl AsRef<opencv::core::Mat> for MatRefMut<'_> {
fn as_ref(&self) -> &opencv::core::Mat {
&self.mat
}
}
impl AsMut<opencv::core::Mat> for MatRefMut<'_> {
fn as_mut(&mut self) -> &mut opencv::core::Mat {
&mut self.mat
}
}
#[derive(Debug, Clone)]
pub struct MatRefMut<'a> {
pub(crate) mat: opencv::core::Mat,
pub(crate) _marker: core::marker::PhantomData<&'a mut ()>,
}
impl MatRefMut<'_> {
pub fn new<'a>(mat: opencv::core::Mat) -> MatRefMut<'a> {
MatRefMut {
mat,
_marker: core::marker::PhantomData,
}
}
}
impl core::ops::Deref for MatRef<'_> {
type Target = opencv::core::Mat;
fn deref(&self) -> &Self::Target {
&self.mat
}
}
impl core::ops::Deref for MatRefMut<'_> {
type Target = opencv::core::Mat;
fn deref(&self) -> &Self::Target {
&self.mat
}
}
impl core::ops::DerefMut for MatRefMut<'_> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.mat
}
}

262
ndcv-bridge/src/fir.rs Normal file
View File

@@ -0,0 +1,262 @@
use error_stack::*;
use fast_image_resize::*;
use images::{Image, ImageRef};
#[derive(Debug, Clone, thiserror::Error)]
#[error("NdFirError")]
pub struct NdFirError;
type Result<T, E = Report<NdFirError>> = std::result::Result<T, E>;
pub trait NdAsImage<T: seal::Sealed, D: ndarray::Dimension>: Sized {
fn as_image_ref(&self) -> Result<ImageRef<'_>>;
}
pub trait NdAsImageMut<T: seal::Sealed, D: ndarray::Dimension>: Sized {
fn as_image_ref_mut(&mut self) -> Result<Image<'_>>;
}
pub struct NdarrayImageContainer<'a, T: seal::Sealed, D: ndarray::Dimension> {
#[allow(dead_code)]
data: ndarray::ArrayView<'a, T, D>,
pub _phantom: std::marker::PhantomData<(T, D)>,
}
impl<'a, T: seal::Sealed> NdarrayImageContainer<'a, T, ndarray::Ix3> {
pub fn new<S: ndarray::Data<Elem = T>>(array: &'a ndarray::ArrayBase<S, ndarray::Ix3>) -> Self {
Self {
data: array.view(),
_phantom: std::marker::PhantomData,
}
}
}
impl<'a, T: seal::Sealed> NdarrayImageContainer<'a, T, ndarray::Ix2> {
pub fn new<S: ndarray::Data<Elem = T>>(array: &'a ndarray::ArrayBase<S, ndarray::Ix2>) -> Self {
Self {
data: array.view(),
_phantom: std::marker::PhantomData,
}
}
}
pub struct NdarrayImageContainerMut<'a, T: seal::Sealed, D: ndarray::Dimension> {
#[allow(dead_code)]
data: ndarray::ArrayViewMut<'a, T, D>,
}
impl<'a, T: seal::Sealed> NdarrayImageContainerMut<'a, T, ndarray::Ix3> {
pub fn new<S: ndarray::DataMut<Elem = T>>(
array: &'a mut ndarray::ArrayBase<S, ndarray::Ix3>,
) -> Self {
Self {
data: array.view_mut(),
}
}
}
impl<'a, T: seal::Sealed> NdarrayImageContainerMut<'a, T, ndarray::Ix2> {
pub fn new<S: ndarray::DataMut<Elem = T>>(
array: &'a mut ndarray::ArrayBase<S, ndarray::Ix2>,
) -> Self {
Self {
data: array.view_mut(),
}
}
}
pub struct NdarrayImageContainerTyped<'a, T: seal::Sealed, D: ndarray::Dimension, P: PixelTrait> {
#[allow(dead_code)]
data: ndarray::ArrayView<'a, T, D>,
__marker: std::marker::PhantomData<P>,
}
// unsafe impl<'a, T: seal::Sealed + Sync + InnerPixel, P: PixelTrait> ImageView
// for NdarrayImageContainerTyped<'a, T, ndarray::Ix3, P>
// where
// T: bytemuck::Pod,
// {
// type Pixel = P;
// fn width(&self) -> u32 {
// self.data.shape()[1] as u32
// }
// fn height(&self) -> u32 {
// self.data.shape()[0] as u32
// }
// fn iter_rows(&self, start_row: u32) -> impl Iterator<Item = &[Self::Pixel]> {
// self.data
// .rows()
// .into_iter()
// .skip(start_row as usize)
// .map(|row| {
// row.as_slice()
// .unwrap_or_default()
// .chunks_exact(P::CHANNELS as usize)
// })
// }
// }
// impl<'a, T: fast_image_resize::pixels::InnerPixel + seal::Sealed, D: ndarray::Dimension>
// fast_image_resize::IntoImageView for NdarrayImageContainer<'a, T, D>
// {
// fn pixel_type(&self) -> Option<PixelType> {
// match D::NDIM {
// Some(2) => Some(to_pixel_type::<T>(1).expect("Failed to convert to pixel type")),
// Some(3) => Some(
// to_pixel_type::<T>(self.data.shape()[2]).expect("Failed to convert to pixel type"),
// ),
// _ => None,
// }
// }
// fn width(&self) -> u32 {
// self.data.shape()[1] as u32
// }
// fn height(&self) -> u32 {
// self.data.shape()[0] as u32
// }
// fn image_view<P: PixelTrait>(&'a self) -> Option<NdarrayImageContainerTyped<'a, T, D, P>> {
// Some(NdarrayImageContainerTyped {
// data: self.data.view(),
// __marker: std::marker::PhantomData,
// })
// }
// }
pub fn to_pixel_type<T: seal::Sealed>(u: usize) -> Result<PixelType> {
match (core::any::type_name::<T>(), u) {
("u8", 1) => Ok(PixelType::U8),
("u8", 2) => Ok(PixelType::U8x2),
("u8", 3) => Ok(PixelType::U8x3),
("u8", 4) => Ok(PixelType::U8x4),
("u16", 1) => Ok(PixelType::U16),
("i32", 1) => Ok(PixelType::I32),
("f32", 1) => Ok(PixelType::F32),
("f32", 2) => Ok(PixelType::F32x2),
("f32", 3) => Ok(PixelType::F32x3),
("f32", 4) => Ok(PixelType::F32x4),
_ => Err(Report::new(NdFirError).attach_printable("Unsupported pixel type")),
}
}
mod seal {
pub trait Sealed {}
impl Sealed for u8 {}
impl Sealed for u16 {}
impl Sealed for i32 {}
impl Sealed for f32 {}
}
impl<S: ndarray::Data<Elem = T>, T: seal::Sealed + bytemuck::Pod, D: ndarray::Dimension>
NdAsImage<T, D> for ndarray::ArrayBase<S, D>
{
/// Clones self and makes a new image
fn as_image_ref(&self) -> Result<ImageRef> {
let shape = self.shape();
let rows = *shape
.first()
.ok_or_else(|| Report::new(NdFirError).attach_printable("Failed to get rows"))?
as u32;
let cols = *shape.get(1).unwrap_or(&1) as u32;
let channels = *shape.get(2).unwrap_or(&1);
let data = self
.as_slice()
.ok_or(NdFirError)
.attach_printable("The ndarray is non continuous")?;
let data_bytes: &[u8] = bytemuck::cast_slice(data);
let pixel_type = to_pixel_type::<T>(channels)?;
ImageRef::new(cols, rows, data_bytes, pixel_type)
.change_context(NdFirError)
.attach_printable("Failed to create Image from ndarray")
}
}
impl<S: ndarray::DataMut<Elem = T>, T: seal::Sealed + bytemuck::Pod, D: ndarray::Dimension>
NdAsImageMut<T, D> for ndarray::ArrayBase<S, D>
{
fn as_image_ref_mut(&mut self) -> Result<Image<'_>>
where
S: ndarray::DataMut<Elem = T>,
{
let shape = self.shape();
let rows = *shape
.first()
.ok_or_else(|| Report::new(NdFirError).attach_printable("Failed to get rows"))?
as u32;
let cols = *shape.get(1).unwrap_or(&1) as u32;
let channels = *shape.get(2).unwrap_or(&1);
let data = self
.as_slice_mut()
.ok_or(NdFirError)
.attach_printable("The ndarray is non continuous")?;
let data_bytes: &mut [u8] = bytemuck::cast_slice_mut(data);
let pixel_type = to_pixel_type::<T>(channels)?;
Image::from_slice_u8(cols, rows, data_bytes, pixel_type)
.change_context(NdFirError)
.attach_printable("Failed to create Image from ndarray")
}
}
pub trait NdFir<T, D> {
fn fast_resize<'o>(
&self,
height: usize,
width: usize,
options: impl Into<Option<&'o ResizeOptions>>,
) -> Result<ndarray::Array<T, D>>;
}
impl<T: seal::Sealed + bytemuck::Pod + num::Zero, S: ndarray::Data<Elem = T>> NdFir<T, ndarray::Ix3>
for ndarray::ArrayBase<S, ndarray::Ix3>
{
fn fast_resize<'o>(
&self,
height: usize,
width: usize,
options: impl Into<Option<&'o ResizeOptions>>,
) -> Result<ndarray::Array3<T>> {
let source = self.as_image_ref()?;
let (_height, _width, channels) = self.dim();
let mut dest = ndarray::Array3::<T>::zeros((height, width, channels));
let mut dest_image = dest.as_image_ref_mut()?;
let mut resizer = fast_image_resize::Resizer::default();
resizer
.resize(&source, &mut dest_image, options)
.change_context(NdFirError)?;
Ok(dest)
}
}
impl<T: seal::Sealed + bytemuck::Pod + num::Zero, S: ndarray::Data<Elem = T>> NdFir<T, ndarray::Ix2>
for ndarray::ArrayBase<S, ndarray::Ix2>
{
fn fast_resize<'o>(
&self,
height: usize,
width: usize,
options: impl Into<Option<&'o ResizeOptions>>,
) -> Result<ndarray::Array<T, ndarray::Ix2>> {
let source = self.as_image_ref()?;
let (_height, _width) = self.dim();
let mut dest = ndarray::Array::<T, ndarray::Ix2>::zeros((height, width));
let mut dest_image = dest.as_image_ref_mut()?;
let mut resizer = fast_image_resize::Resizer::default();
resizer
.resize(&source, &mut dest_image, options)
.change_context(NdFirError)?;
Ok(dest)
}
}
#[test]
pub fn test_ndarray_fast_image_resize_u8() {
let source_fhd = ndarray::Array3::<u8>::ones((1920, 1080, 3));
let mut resized_hd = ndarray::Array3::<u8>::zeros((1280, 720, 3));
let mut resizer = fast_image_resize::Resizer::default();
resizer
.resize(
&source_fhd.as_image_ref().unwrap(),
&mut resized_hd.as_image_ref_mut().unwrap(),
None,
)
.unwrap();
assert_eq!(resized_hd.shape(), [1280, 720, 3]);
}

307
ndcv-bridge/src/gaussian.rs Normal file
View File

@@ -0,0 +1,307 @@
//! <https://docs.rs/opencv/latest/opencv/imgproc/fn.gaussian_blur.html>
use crate::conversions::*;
use crate::prelude_::*;
use ndarray::*;
#[repr(C)]
#[derive(Default, Debug, Copy, Clone)]
pub enum BorderType {
#[default]
BorderConstant = 0,
BorderReplicate = 1,
BorderReflect = 2,
BorderWrap = 3,
BorderReflect101 = 4,
BorderTransparent = 5,
BorderIsolated = 16,
}
#[repr(C)]
#[derive(Default, Debug, Copy, Clone)]
pub enum AlgorithmHint {
#[default]
AlgoHintDefault = 0,
AlgoHintAccurate = 1,
AlgoHintApprox = 2,
}
mod seal {
pub trait Sealed {}
// src: input image; the image can have any number of channels, which are processed independently, but the depth should be CV_8U, CV_16U, CV_16S, CV_32F or CV_64F.
impl Sealed for u8 {}
impl Sealed for u16 {}
impl Sealed for i16 {}
impl Sealed for f32 {}
impl Sealed for f64 {}
}
pub trait NdCvGaussianBlur<T: bytemuck::Pod + seal::Sealed, D: ndarray::Dimension>:
crate::image::NdImage + crate::conversions::NdAsImage<T, D>
{
fn gaussian_blur(
&self,
kernel_size: (u8, u8),
sigma_x: f64,
sigma_y: f64,
border_type: BorderType,
) -> Result<ndarray::Array<T, D>, NdCvError>;
fn gaussian_blur_def(
&self,
kernel: (u8, u8),
sigma_x: f64,
) -> Result<ndarray::Array<T, D>, NdCvError> {
self.gaussian_blur(kernel, sigma_x, sigma_x, BorderType::BorderConstant)
}
}
impl<
T: bytemuck::Pod + num::Zero + seal::Sealed,
S: ndarray::RawData + ndarray::Data<Elem = T>,
D: ndarray::Dimension,
> NdCvGaussianBlur<T, D> for ArrayBase<S, D>
where
ndarray::ArrayBase<S, D>: crate::image::NdImage + crate::conversions::NdAsImage<T, D>,
ndarray::Array<T, D>: crate::conversions::NdAsImageMut<T, D>,
{
fn gaussian_blur(
&self,
kernel_size: (u8, u8),
sigma_x: f64,
sigma_y: f64,
border_type: BorderType,
) -> Result<ndarray::Array<T, D>, NdCvError> {
let mut dst = ndarray::Array::zeros(self.dim());
let cv_self = self.as_image_mat()?;
let mut cv_dst = dst.as_image_mat_mut()?;
opencv::imgproc::gaussian_blur(
&*cv_self,
&mut *cv_dst,
opencv::core::Size {
width: kernel_size.0 as i32,
height: kernel_size.1 as i32,
},
sigma_x,
sigma_y,
border_type as i32,
)
.change_context(NdCvError)
.attach_printable("Failed to apply gaussian blur")?;
Ok(dst)
}
}
// impl<
// T: bytemuck::Pod + num::Zero + seal::Sealed,
// S: ndarray::RawData + ndarray::Data<Elem = T>,
// > NdCvGaussianBlur<T, Ix3> for ArrayBase<S, Ix3>
// {
// fn gaussian_blur(
// &self,
// kernel_size: (u8, u8),
// sigma_x: f64,
// sigma_y: f64,
// border_type: BorderType,
// ) -> Result<ndarray::Array<T, Ix3>, NdCvError> {
// let mut dst = ndarray::Array::zeros(self.dim());
// let cv_self = self.as_image_mat()?;
// let mut cv_dst = dst.as_image_mat_mut()?;
// opencv::imgproc::gaussian_blur(
// &*cv_self,
// &mut *cv_dst,
// opencv::core::Size {
// width: kernel_size.0 as i32,
// height: kernel_size.1 as i32,
// },
// sigma_x,
// sigma_y,
// border_type as i32,
// )
// .change_context(NdCvError)
// .attach_printable("Failed to apply gaussian blur")?;
// Ok(dst)
// }
// }
//
// impl<
// T: bytemuck::Pod + num::Zero + seal::Sealed,
// S: ndarray::RawData + ndarray::Data<Elem = T>,
// > NdCvGaussianBlur<T, Ix2> for ArrayBase<S, Ix2>
// {
// fn gaussian_blur(
// &self,
// kernel_size: (u8, u8),
// sigma_x: f64,
// sigma_y: f64,
// border_type: BorderType,
// ) -> Result<ndarray::Array<T, Ix2>, NdCvError> {
// let mut dst = ndarray::Array::zeros(self.dim());
// let cv_self = self.as_image_mat()?;
// let mut cv_dst = dst.as_image_mat_mut()?;
// opencv::imgproc::gaussian_blur(
// &*cv_self,
// &mut *cv_dst,
// opencv::core::Size {
// width: kernel_size.0 as i32,
// height: kernel_size.1 as i32,
// },
// sigma_x,
// sigma_y,
// border_type as i32,
// )
// .change_context(NdCvError)
// .attach_printable("Failed to apply gaussian blur")?;
// Ok(dst)
// }
// }
/// For smaller values it is faster to use the allocated version
/// For example in a 4k f32 image this is about 50% faster than the allocated one
pub trait NdCvGaussianBlurInPlace<T: bytemuck::Pod + seal::Sealed, D: ndarray::Dimension>:
crate::image::NdImage + crate::conversions::NdAsImageMut<T, D>
{
fn gaussian_blur_inplace(
&mut self,
kernel_size: (u8, u8),
sigma_x: f64,
sigma_y: f64,
border_type: BorderType,
) -> Result<&mut Self, NdCvError>;
fn gaussian_blur_def_inplace(
&mut self,
kernel: (u8, u8),
sigma_x: f64,
) -> Result<&mut Self, NdCvError> {
self.gaussian_blur_inplace(kernel, sigma_x, sigma_x, BorderType::BorderConstant)
}
}
impl<
T: bytemuck::Pod + num::Zero + seal::Sealed,
S: ndarray::RawData + ndarray::DataMut<Elem = T>,
D: ndarray::Dimension,
> NdCvGaussianBlurInPlace<T, D> for ArrayBase<S, D>
where
Self: crate::image::NdImage + crate::conversions::NdAsImageMut<T, D>,
{
fn gaussian_blur_inplace(
&mut self,
kernel_size: (u8, u8),
sigma_x: f64,
sigma_y: f64,
border_type: BorderType,
) -> Result<&mut Self, NdCvError> {
let mut cv_self = self.as_image_mat_mut()?;
unsafe {
crate::inplace::op_inplace(&mut *cv_self, |this, out| {
opencv::imgproc::gaussian_blur(
this,
out,
opencv::core::Size {
width: kernel_size.0 as i32,
height: kernel_size.1 as i32,
},
sigma_x,
sigma_y,
border_type as i32,
)
})
}
.change_context(NdCvError)
.attach_printable("Failed to apply gaussian blur")?;
Ok(self)
}
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::Array3;
#[test]
fn test_gaussian_basic() {
let arr = Array3::<u8>::ones((10, 10, 3));
let kernel_size = (3, 3);
let sigma_x = 0.0;
let sigma_y = 0.0;
let border_type = BorderType::BorderConstant;
let res = arr
.gaussian_blur(kernel_size, sigma_x, sigma_y, border_type)
.unwrap();
assert_eq!(res.shape(), &[10, 10, 3]);
}
#[test]
fn test_gaussian_edge_preservation() {
// Create an image with a sharp edge
let mut arr = Array3::<u8>::zeros((10, 10, 3));
arr.slice_mut(s![..5, .., ..]).fill(255); // Top half white, bottom half black
let res = arr
.gaussian_blur((3, 3), 1.0, 1.0, BorderType::BorderConstant)
.unwrap();
// Check that the middle row (edge) has intermediate values
let middle_row = res.slice(s![4..6, 5, 0]);
assert!(middle_row.iter().all(|&x| x > 0 && x < 255));
}
#[test]
fn test_gaussian_different_kernel_sizes() {
let arr = Array3::<u8>::ones((20, 20, 3));
// Test different kernel sizes
let kernel_sizes = [(3, 3), (5, 5), (7, 7)];
for &kernel_size in &kernel_sizes {
let res = arr
.gaussian_blur(kernel_size, 1.0, 1.0, BorderType::BorderConstant)
.unwrap();
assert_eq!(res.shape(), &[20, 20, 3]);
}
}
#[test]
fn test_gaussian_different_border_types() {
let mut arr = Array3::<u8>::zeros((10, 10, 3));
arr.slice_mut(s![4..7, 4..7, ..]).fill(255); // White square in center
let border_types = [
BorderType::BorderConstant,
BorderType::BorderReplicate,
BorderType::BorderReflect,
BorderType::BorderReflect101,
];
for border_type in border_types {
let res = arr.gaussian_blur((3, 3), 1.0, 1.0, border_type).unwrap();
assert_eq!(res.shape(), &[10, 10, 3]);
}
}
#[test]
fn test_gaussian_different_types() {
// Test with different numeric types
let arr_u8 = Array3::<u8>::ones((10, 10, 3));
let arr_f32 = Array3::<f32>::ones((10, 10, 3));
let res_u8 = arr_u8
.gaussian_blur((3, 3), 1.0, 1.0, BorderType::BorderConstant)
.unwrap();
let res_f32 = arr_f32
.gaussian_blur((3, 3), 1.0, 1.0, BorderType::BorderConstant)
.unwrap();
assert_eq!(res_u8.shape(), &[10, 10, 3]);
assert_eq!(res_f32.shape(), &[10, 10, 3]);
}
#[test]
#[should_panic]
fn test_gaussian_invalid_kernel_size() {
let arr = Array3::<u8>::ones((10, 10, 3));
// Even kernel sizes should fail
let _ = arr
.gaussian_blur((2, 2), 1.0, 1.0, BorderType::BorderConstant)
.unwrap();
}
}

30
ndcv-bridge/src/image.rs Normal file
View File

@@ -0,0 +1,30 @@
use ndarray::*;
pub trait NdImage {
fn width(&self) -> usize;
fn height(&self) -> usize;
fn channels(&self) -> usize;
}
impl<T, S: RawData<Elem = T>> NdImage for ArrayBase<S, Ix3> {
fn width(&self) -> usize {
self.dim().1
}
fn height(&self) -> usize {
self.dim().0
}
fn channels(&self) -> usize {
self.dim().2
}
}
impl<T, S: RawData<Elem = T>> NdImage for ArrayBase<S, Ix2> {
fn width(&self) -> usize {
self.dim().1
}
fn height(&self) -> usize {
self.dim().0
}
fn channels(&self) -> usize {
1
}
}

View File

@@ -0,0 +1,14 @@
use opencv::core::Mat;
use opencv::prelude::*;
use opencv::Result;
#[inline(always)]
pub(crate) unsafe fn op_inplace<T>(
m: &mut Mat,
f: impl FnOnce(&Mat, &mut Mat) -> Result<T>,
) -> Result<T> {
let mut m_alias = Mat::from_raw(m.as_raw_mut());
let out = f(m, &mut m_alias);
let _ = m_alias.into_raw();
out
}

83
ndcv-bridge/src/lib.rs Normal file
View File

@@ -0,0 +1,83 @@
//! Methods and type conversions for ndarray to opencv and vice versa
mod blend;
// mod dilate;
pub mod fir;
mod image;
mod inplace;
pub mod percentile;
mod roi;
#[cfg(feature = "opencv")]
pub mod bounding_rect;
// #[cfg(feature = "opencv")]
// pub mod color_space;
#[cfg(feature = "opencv")]
pub mod connected_components;
#[cfg(feature = "opencv")]
pub mod contours;
#[cfg(feature = "opencv")]
pub mod conversions;
// #[cfg(feature = "opencv")]
// pub mod gaussian;
#[cfg(feature = "opencv")]
pub mod resize;
pub mod codec;
pub mod orient;
pub use blend::NdBlend;
pub use fast_image_resize::{FilterType, ResizeAlg, ResizeOptions, Resizer};
pub use fir::NdFir;
// pub use gaussian::{BorderType, NdCvGaussianBlur, NdCvGaussianBlurInPlace};
pub use roi::{NdRoi, NdRoiMut, NdRoiZeroPadded};
#[cfg(feature = "opencv")]
pub use contours::{
ContourApproximationMethod, ContourHierarchy, ContourResult, ContourRetrievalMode,
NdCvContourArea, NdCvFindContours,
};
#[cfg(feature = "opencv")]
pub use bounding_rect::BoundingRect;
#[cfg(feature = "opencv")]
pub use connected_components::{Connectivity, NdCvConnectedComponents};
#[cfg(feature = "opencv")]
pub use conversions::{MatAsNd, NdAsImage, NdAsImageMut, NdAsMat, NdAsMatMut, NdCvConversion};
#[cfg(feature = "opencv")]
pub use resize::{Interpolation, NdCvResize};
pub(crate) mod prelude_ {
pub use crate::NdCvError;
pub use error_stack::*;
}
#[derive(Debug, thiserror::Error)]
#[error("NdCvError")]
pub struct NdCvError;
#[cfg(feature = "opencv")]
pub fn type_depth<T>() -> i32 {
match std::any::type_name::<T>() {
"u8" => opencv::core::CV_8U,
"i8" => opencv::core::CV_8S,
"u16" => opencv::core::CV_16U,
"i16" => opencv::core::CV_16S,
"i32" => opencv::core::CV_32S,
"f32" => opencv::core::CV_32F,
"f64" => opencv::core::CV_64F,
_ => panic!("Unsupported type"),
}
}
#[cfg(feature = "opencv")]
pub fn depth_type(depth: i32) -> &'static str {
match depth {
opencv::core::CV_8U => "u8",
opencv::core::CV_8S => "i8",
opencv::core::CV_16U => "u16",
opencv::core::CV_16S => "i16",
opencv::core::CV_32S => "i32",
opencv::core::CV_32F => "f32",
opencv::core::CV_64F => "f64",
_ => panic!("Unsupported depth"),
}
}

188
ndcv-bridge/src/orient.rs Normal file
View File

@@ -0,0 +1,188 @@
use ndarray::{Array, ArrayBase, ArrayView};
#[derive(Clone, Copy)]
pub enum Orientation {
NoRotation,
Mirror,
Clock180,
Water,
MirrorClock270,
Clock90,
MirrorClock90,
Clock270,
Unknown,
}
impl Orientation {
pub fn inverse(&self) -> Self {
match self {
Self::Clock90 => Self::Clock270,
Self::Clock270 => Self::Clock90,
_ => *self,
}
}
}
impl Orientation {
pub fn from_raw(flip: u8) -> Self {
match flip {
1 => Orientation::NoRotation,
2 => Orientation::Mirror,
3 => Orientation::Clock180,
4 => Orientation::Water,
5 => Orientation::MirrorClock270,
6 => Orientation::Clock90,
7 => Orientation::MirrorClock90,
8 => Orientation::Clock270,
_ => Orientation::Unknown,
}
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RotationFlag {
Clock90,
Clock180,
Clock270,
}
impl RotationFlag {
pub fn neg(&self) -> Self {
match self {
RotationFlag::Clock90 => RotationFlag::Clock270,
RotationFlag::Clock180 => RotationFlag::Clock180,
RotationFlag::Clock270 => RotationFlag::Clock90,
}
}
pub fn to_orientation(&self) -> Orientation {
match self {
RotationFlag::Clock90 => Orientation::Clock90,
RotationFlag::Clock180 => Orientation::Clock180,
RotationFlag::Clock270 => Orientation::Clock270,
}
}
}
#[derive(Clone, Copy)]
pub enum FlipFlag {
Mirror,
Water,
Both,
}
pub trait Orient<T: bytemuck::Pod, D: ndarray::Dimension> {
fn flip(&self, flip: FlipFlag) -> Array<T, D>;
fn rotate(&self, rotation: RotationFlag) -> Array<T, D>;
fn owned(&self) -> Array<T, D>;
fn unorient(&self, orientation: Orientation) -> Array<T, D>
where
Array<T, D>: Orient<T, D>,
Self: ToOwned<Owned = Array<T, D>>,
{
let inverse_orientation = orientation.inverse();
self.orient(inverse_orientation)
// match orientation {
// Orientation::NoRotation | Orientation::Unknown => self.to_owned(),
// Orientation::Mirror => self.flip(FlipFlag::Mirror).to_owned(),
// Orientation::Clock180 => self.rotate(RotationFlag::Clock180),
// Orientation::Water => self.flip(FlipFlag::Water).to_owned(),
// Orientation::MirrorClock270 => self
// .rotate(RotationFlag::Clock90)
// .flip(FlipFlag::Mirror)
// .to_owned(),
// Orientation::Clock90 => self.rotate(RotationFlag::Clock270),
// Orientation::MirrorClock90 => self
// .rotate(RotationFlag::Clock270)
// .flip(FlipFlag::Mirror)
// .to_owned(),
// Orientation::Clock270 => self.rotate(RotationFlag::Clock90),
// }
}
fn orient(&self, orientation: Orientation) -> Array<T, D>
where
Array<T, D>: Orient<T, D>,
{
match orientation {
Orientation::NoRotation | Orientation::Unknown => self.owned(),
Orientation::Mirror => self.flip(FlipFlag::Mirror).to_owned(),
Orientation::Clock180 => self.rotate(RotationFlag::Clock180),
Orientation::Water => self.flip(FlipFlag::Water).to_owned(),
Orientation::MirrorClock270 => self
.flip(FlipFlag::Mirror)
.rotate(RotationFlag::Clock270)
.to_owned(),
Orientation::Clock90 => self.rotate(RotationFlag::Clock90),
Orientation::MirrorClock90 => self
.flip(FlipFlag::Mirror)
.rotate(RotationFlag::Clock90)
.to_owned(),
Orientation::Clock270 => self.rotate(RotationFlag::Clock270),
}
.as_standard_layout()
.to_owned()
}
}
impl<T: bytemuck::Pod + Copy, S: ndarray::Data<Elem = T>> Orient<T, ndarray::Ix3>
for ArrayBase<S, ndarray::Ix3>
{
fn flip(&self, flip: FlipFlag) -> Array<T, ndarray::Ix3> {
match flip {
FlipFlag::Mirror => self.slice(ndarray::s![.., ..;-1, ..]),
FlipFlag::Water => self.slice(ndarray::s![..;-1, .., ..]),
FlipFlag::Both => self.slice(ndarray::s![..;-1, ..;-1, ..]),
}
.as_standard_layout()
.to_owned()
}
fn owned(&self) -> Array<T, ndarray::Ix3> {
self.to_owned()
}
fn rotate(&self, rotation: RotationFlag) -> Array<T, ndarray::Ix3> {
match rotation {
RotationFlag::Clock90 => self
.view()
.permuted_axes([1, 0, 2])
.flip(FlipFlag::Mirror)
.to_owned(),
RotationFlag::Clock180 => self.flip(FlipFlag::Both).to_owned(),
RotationFlag::Clock270 => self
.view()
.permuted_axes([1, 0, 2])
.flip(FlipFlag::Water)
.to_owned(),
}
}
}
impl<T: bytemuck::Pod + Copy, S: ndarray::Data<Elem = T>> Orient<T, ndarray::Ix2>
for ArrayBase<S, ndarray::Ix2>
{
fn flip(&self, flip: FlipFlag) -> Array<T, ndarray::Ix2> {
match flip {
FlipFlag::Mirror => self.slice(ndarray::s![.., ..;-1,]),
FlipFlag::Water => self.slice(ndarray::s![..;-1, ..,]),
FlipFlag::Both => self.slice(ndarray::s![..;-1, ..;-1,]),
}
.as_standard_layout()
.to_owned()
}
fn owned(&self) -> Array<T, ndarray::Ix2> {
self.to_owned()
}
fn rotate(&self, rotation: RotationFlag) -> Array<T, ndarray::Ix2> {
match rotation {
RotationFlag::Clock90 => self.t().flip(FlipFlag::Mirror).to_owned(),
RotationFlag::Clock180 => self.flip(FlipFlag::Both).to_owned(),
RotationFlag::Clock270 => self.t().flip(FlipFlag::Water).to_owned(),
}
}
}

View File

@@ -0,0 +1,63 @@
use error_stack::*;
use ndarray::{ArrayBase, Ix1};
use num::cast::AsPrimitive;
use crate::NdCvError;
pub trait Percentile {
fn percentile(&self, qth_percentile: f64) -> Result<f64, NdCvError>;
}
impl<T: std::cmp::Ord + Clone + AsPrimitive<f64>, S: ndarray::Data<Elem = T>> Percentile
for ArrayBase<S, Ix1>
{
fn percentile(&self, qth_percentile: f64) -> Result<f64, NdCvError> {
if self.len() == 0 {
return Err(error_stack::Report::new(NdCvError).attach_printable("Empty Input"));
}
if !(0_f64..1_f64).contains(&qth_percentile) {
return Err(error_stack::Report::new(NdCvError)
.attach_printable("Qth percentile must be between 0 and 1"));
}
let mut standard_array = self.as_standard_layout();
let mut raw_data = standard_array
.as_slice_mut()
.expect("An array in standard layout will always return its inner slice");
raw_data.sort();
let actual_index = qth_percentile * (raw_data.len() - 1) as f64;
let lower_index = (actual_index.floor() as usize).clamp(0, raw_data.len() - 1);
let upper_index = (actual_index.ceil() as usize).clamp(0, raw_data.len() - 1);
if lower_index == upper_index {
Ok(raw_data[lower_index].as_())
} else {
let weight = actual_index - lower_index as f64;
Ok(raw_data[lower_index].as_() * (1.0 - weight) + raw_data[upper_index].as_() * weight)
}
}
}
// fn percentile(data: &Array1<f64>, p: f64) -> f64 {
// if data.len() == 0 {
// return 0.0;
// }
//
// let mut sorted_data = data.to_vec();
// sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap());
//
// let index = (p / 100.0) * (sorted_data.len() - 1) as f64;
// let lower = index.floor() as usize;
// let upper = index.ceil() as usize;
//
// if lower == upper {
// sorted_data[lower] as f64
// } else {
// let weight = index - lower as f64;
// sorted_data[lower] as f64 * (1.0 - weight) + sorted_data[upper] as f64 * weight
// }
// }

108
ndcv-bridge/src/resize.rs Normal file
View File

@@ -0,0 +1,108 @@
use crate::{prelude_::*, NdAsImage, NdAsImageMut};
/// Resize ndarray using OpenCV resize functions
pub trait NdCvResize<T, D>: seal::SealedInternal {
/// The input array must be a continuous 2D or 3D ndarray
fn resize(
&self,
height: u16,
width: u16,
interpolation: Interpolation,
) -> Result<ndarray::ArrayBase<ndarray::OwnedRepr<T>, D>, NdCvError>;
}
#[repr(i32)]
#[derive(Debug, Copy, Clone)]
pub enum Interpolation {
Linear = opencv::imgproc::INTER_LINEAR,
LinearExact = opencv::imgproc::INTER_LINEAR_EXACT,
Max = opencv::imgproc::INTER_MAX,
Area = opencv::imgproc::INTER_AREA,
Cubic = opencv::imgproc::INTER_CUBIC,
Nearest = opencv::imgproc::INTER_NEAREST,
NearestExact = opencv::imgproc::INTER_NEAREST_EXACT,
Lanczos4 = opencv::imgproc::INTER_LANCZOS4,
}
mod seal {
pub trait SealedInternal {}
impl<T: bytemuck::Pod, S: ndarray::Data<Elem = T>> SealedInternal
for ndarray::ArrayBase<S, ndarray::Ix3>
{
}
impl<T: bytemuck::Pod, S: ndarray::Data<Elem = T>> SealedInternal
for ndarray::ArrayBase<S, ndarray::Ix2>
{
}
}
impl<T: bytemuck::Pod + num::Zero, S: ndarray::Data<Elem = T>> NdCvResize<T, ndarray::Ix2>
for ndarray::ArrayBase<S, ndarray::Ix2>
{
fn resize(
&self,
height: u16,
width: u16,
interpolation: Interpolation,
) -> Result<ndarray::Array2<T>, NdCvError> {
let mat = self.as_image_mat()?;
let mut dest = ndarray::Array2::zeros((height.into(), width.into()));
let mut dest_mat = dest.as_image_mat_mut()?;
opencv::imgproc::resize(
mat.as_ref(),
dest_mat.as_mut(),
opencv::core::Size {
height: height.into(),
width: width.into(),
},
0.,
0.,
interpolation as i32,
)
.change_context(NdCvError)?;
Ok(dest)
}
}
impl<T: bytemuck::Pod + num::Zero, S: ndarray::Data<Elem = T>> NdCvResize<T, ndarray::Ix3>
for ndarray::ArrayBase<S, ndarray::Ix3>
{
fn resize(
&self,
height: u16,
width: u16,
interpolation: Interpolation,
) -> Result<ndarray::ArrayBase<ndarray::OwnedRepr<T>, ndarray::Ix3>, NdCvError> {
let mat = self.as_image_mat()?;
let mut dest =
ndarray::Array3::zeros((height.into(), width.into(), self.len_of(ndarray::Axis(2))));
let mut dest_mat = dest.as_image_mat_mut()?;
opencv::imgproc::resize(
mat.as_ref(),
dest_mat.as_mut(),
opencv::core::Size {
height: height.into(),
width: width.into(),
},
0.,
0.,
interpolation as i32,
)
.change_context(NdCvError)?;
Ok(dest)
}
}
#[test]
fn test_resize_simple() {
let foo = ndarray::Array2::<u8>::ones((10, 10));
let foo_resized = foo.resize(15, 20, Interpolation::Linear).unwrap();
assert_eq!(foo_resized, ndarray::Array2::<u8>::ones((15, 20)));
}
#[test]
fn test_resize_3d() {
let foo = ndarray::Array3::<u8>::ones((10, 10, 3));
let foo_resized = foo.resize(15, 20, Interpolation::Linear).unwrap();
assert_eq!(foo_resized, ndarray::Array3::<u8>::ones((15, 20, 3)));
}

274
ndcv-bridge/src/roi.rs Normal file
View File

@@ -0,0 +1,274 @@
pub trait NdRoi<T, D>: seal::Sealed {
fn roi(&self, rect: bounding_box::Aabb2<usize>) -> ndarray::ArrayView<T, D>;
}
pub trait NdRoiMut<T, D>: seal::Sealed {
fn roi_mut(&mut self, rect: bounding_box::Aabb2<usize>) -> ndarray::ArrayViewMut<T, D>;
}
mod seal {
use ndarray::{Ix2, Ix3};
pub trait Sealed {}
impl<T: bytemuck::Pod, S: ndarray::Data<Elem = T>> Sealed for ndarray::ArrayBase<S, Ix2> {}
impl<T: bytemuck::Pod, S: ndarray::Data<Elem = T>> Sealed for ndarray::ArrayBase<S, Ix3> {}
}
impl<T: bytemuck::Pod, S: ndarray::Data<Elem = T>> NdRoi<T, ndarray::Ix3>
for ndarray::ArrayBase<S, ndarray::Ix3>
{
fn roi(&self, rect: bounding_box::Aabb2<usize>) -> ndarray::ArrayView3<T> {
let y1 = rect.y1();
let y2 = rect.y2();
let x1 = rect.x1();
let x2 = rect.x2();
self.slice(ndarray::s![y1..y2, x1..x2, ..])
}
}
impl<T: bytemuck::Pod, S: ndarray::DataMut<Elem = T>> NdRoiMut<T, ndarray::Ix3>
for ndarray::ArrayBase<S, ndarray::Ix3>
{
fn roi_mut(&mut self, rect: bounding_box::Aabb2<usize>) -> ndarray::ArrayViewMut3<T> {
let y1 = rect.y1();
let y2 = rect.y2();
let x1 = rect.x1();
let x2 = rect.x2();
self.slice_mut(ndarray::s![y1..y2, x1..x2, ..])
}
}
impl<T: bytemuck::Pod, S: ndarray::Data<Elem = T>> NdRoi<T, ndarray::Ix2>
for ndarray::ArrayBase<S, ndarray::Ix2>
{
fn roi(&self, rect: bounding_box::Aabb2<usize>) -> ndarray::ArrayView2<T> {
let y1 = rect.y1();
let y2 = rect.y2();
let x1 = rect.x1();
let x2 = rect.x2();
self.slice(ndarray::s![y1..y2, x1..x2])
}
}
impl<T: bytemuck::Pod, S: ndarray::DataMut<Elem = T>> NdRoiMut<T, ndarray::Ix2>
for ndarray::ArrayBase<S, ndarray::Ix2>
{
fn roi_mut(&mut self, rect: bounding_box::Aabb2<usize>) -> ndarray::ArrayViewMut2<T> {
let y1 = rect.y1();
let y2 = rect.y2();
let x1 = rect.x1();
let x2 = rect.x2();
self.slice_mut(ndarray::s![y1..y2, x1..x2])
}
}
#[test]
fn test_roi() {
let arr = ndarray::Array3::<u8>::zeros((10, 10, 3));
let rect = bounding_box::Aabb2::from_xywh(1, 1, 3, 3);
let roi = arr.roi(rect);
assert_eq!(roi.shape(), &[3, 3, 3]);
}
#[test]
fn test_roi_2d() {
let arr = ndarray::Array2::<u8>::zeros((10, 10));
let rect = bounding_box::Aabb2::from_xywh(1, 1, 3, 3);
let roi = arr.roi(rect);
assert_eq!(roi.shape(), &[3, 3]);
}
/// ```text
/// ┌──────────────────┐
/// │ padded │
/// │ ┌────────┐ │
/// │ │ │ │
/// │ │original│ │
/// │ │ │ │
/// │ └────────┘ │
/// │ zeroed │
/// └──────────────────┘
/// ```
///
/// Returns the padded bounding box and the padded segment
/// The padded is the padded bounding box
/// The original is the original bounding box
/// Returns the padded bounding box as zeros and the original bbox as the original segment
// Helper functions for missing methods from old bbox crate
fn bbox_top_left_usize(bbox: &bounding_box::Aabb2<usize>) -> (usize, usize) {
(bbox.x1(), bbox.y1())
}
fn bbox_with_top_left_usize(
bbox: &bounding_box::Aabb2<usize>,
x: usize,
y: usize,
) -> bounding_box::Aabb2<usize> {
let width = bbox.x2() - bbox.x1();
let height = bbox.y2() - bbox.y1();
bounding_box::Aabb2::from_xywh(x, y, width, height)
}
fn bbox_with_origin_usize(point: (usize, usize), origin: (usize, usize)) -> (usize, usize) {
(point.0 - origin.0, point.1 - origin.1)
}
fn bbox_zeros_ndarray_2d<T: num::Zero + Copy>(
bbox: &bounding_box::Aabb2<usize>,
) -> ndarray::Array2<T> {
let width = bbox.x2() - bbox.x1();
let height = bbox.y2() - bbox.y1();
ndarray::Array2::<T>::zeros((height, width))
}
fn bbox_zeros_ndarray_3d<T: num::Zero + Copy>(
bbox: &bounding_box::Aabb2<usize>,
channels: usize,
) -> ndarray::Array3<T> {
let width = bbox.x2() - bbox.x1();
let height = bbox.y2() - bbox.y1();
ndarray::Array3::<T>::zeros((height, width, channels))
}
fn bbox_round_f64(bbox: &bounding_box::Aabb2<f64>) -> bounding_box::Aabb2<f64> {
let x1 = bbox.x1().round();
let y1 = bbox.y1().round();
let x2 = bbox.x2().round();
let y2 = bbox.y2().round();
bounding_box::Aabb2::from_x1y1x2y2(x1, y1, x2, y2)
}
fn bbox_cast_f64_to_usize(bbox: &bounding_box::Aabb2<f64>) -> bounding_box::Aabb2<usize> {
let x1 = bbox.x1() as usize;
let y1 = bbox.y1() as usize;
let x2 = bbox.x2() as usize;
let y2 = bbox.y2() as usize;
bounding_box::Aabb2::from_x1y1x2y2(x1, y1, x2, y2)
}
pub trait NdRoiZeroPadded<T, D: ndarray::Dimension>: seal::Sealed {
fn roi_zero_padded(
&self,
original: bounding_box::Aabb2<usize>,
padded: bounding_box::Aabb2<usize>,
) -> (bounding_box::Aabb2<usize>, ndarray::Array<T, D>);
}
impl<T: bytemuck::Pod + num::Zero> NdRoiZeroPadded<T, ndarray::Ix2> for ndarray::Array2<T> {
fn roi_zero_padded(
&self,
original: bounding_box::Aabb2<usize>,
padded: bounding_box::Aabb2<usize>,
) -> (bounding_box::Aabb2<usize>, ndarray::Array2<T>) {
// The co-ordinates of both the original and the padded bounding boxes must be contained in
// self or it will panic
let self_bbox = bounding_box::Aabb2::from_xywh(0, 0, self.shape()[1], self.shape()[0]);
if !self_bbox.contains_bbox(&original) {
panic!("original bounding box is not contained in self");
}
if !self_bbox.contains_bbox(&padded) {
panic!("padded bounding box is not contained in self");
}
let original_top_left = bbox_top_left_usize(&original);
let padded_top_left = bbox_top_left_usize(&padded);
let origin_offset = bbox_with_origin_usize(original_top_left, padded_top_left);
let original_roi_in_padded =
bbox_with_top_left_usize(&original, origin_offset.0, origin_offset.1);
let original_segment = self.roi(original);
let mut padded_segment = bbox_zeros_ndarray_2d::<T>(&padded);
padded_segment
.roi_mut(original_roi_in_padded)
.assign(&original_segment);
(padded, padded_segment)
}
}
impl<T: bytemuck::Pod + num::Zero> NdRoiZeroPadded<T, ndarray::Ix3> for ndarray::Array3<T> {
fn roi_zero_padded(
&self,
original: bounding_box::Aabb2<usize>,
padded: bounding_box::Aabb2<usize>,
) -> (bounding_box::Aabb2<usize>, ndarray::Array3<T>) {
let self_bbox = bounding_box::Aabb2::from_xywh(0, 0, self.shape()[1], self.shape()[0]);
if !self_bbox.contains_bbox(&original) {
panic!("original bounding box is not contained in self");
}
if !self_bbox.contains_bbox(&padded) {
panic!("padded bounding box is not contained in self");
}
let original_top_left = bbox_top_left_usize(&original);
let padded_top_left = bbox_top_left_usize(&padded);
let origin_offset = bbox_with_origin_usize(original_top_left, padded_top_left);
let original_roi_in_padded =
bbox_with_top_left_usize(&original, origin_offset.0, origin_offset.1);
let original_segment = self.roi(original);
let mut padded_segment = bbox_zeros_ndarray_3d::<T>(&padded, self.len_of(ndarray::Axis(2)));
padded_segment
.roi_mut(original_roi_in_padded)
.assign(&original_segment);
(padded, padded_segment)
}
}
#[test]
fn test_roi_zero_padded() {
let arr = ndarray::Array2::<u8>::ones((10, 10));
let original = bounding_box::Aabb2::from_xywh(1.0, 1.0, 3.0, 3.0);
let clamp = bounding_box::Aabb2::from_xywh(0.0, 0.0, 10.0, 10.0);
let padded = original.padding(2.0).clamp(&clamp).unwrap();
let padded_cast = bbox_cast_f64_to_usize(&padded);
let original_cast = bbox_cast_f64_to_usize(&original);
let (padded_result, padded_segment) = arr.roi_zero_padded(original_cast, padded_cast);
assert_eq!(padded_result, bounding_box::Aabb2::from_xywh(0, 0, 6, 6));
assert_eq!(padded_segment.shape(), &[6, 6]);
}
#[test]
pub fn bbox_clamp_failure_preload() {
let segment_mask = ndarray::Array2::<u8>::zeros((512, 512));
let og = bounding_box::Aabb2::from_xywh(475.0, 79.625, 37.0, 282.15);
let clamp = bounding_box::Aabb2::from_xywh(0.0, 0.0, 512.0, 512.0);
let padded = og
.scale(nalgebra::Vector2::new(1.2, 1.2))
.clamp(&clamp)
.unwrap();
let padded = bbox_round_f64(&padded);
let og_cast = bbox_cast_f64_to_usize(&bbox_round_f64(&og));
let padded_cast = bbox_cast_f64_to_usize(&padded);
let (_bbox, _segment) = segment_mask.roi_zero_padded(og_cast, padded_cast);
}
#[test]
pub fn bbox_clamp_failure_preload_2() {
let segment_mask = ndarray::Array2::<u8>::zeros((512, 512));
let bbox = bounding_box::Aabb2::from_xywh(354.25, 98.5, 116.25, 413.5);
// let padded = bounding_box::Aabb2::from_xywh(342.625, 57.150000000000006, 139.5, 454.85);
let clamp = bounding_box::Aabb2::from_xywh(0.0, 0.0, 512.0, 512.0);
let padded = bbox
.scale(nalgebra::Vector2::new(1.2, 1.2))
.clamp(&clamp)
.unwrap();
let padded = bbox_round_f64(&padded);
let bbox_cast = bbox_cast_f64_to_usize(&bbox_round_f64(&bbox));
let padded_cast = bbox_cast_f64_to_usize(&padded);
let (_bbox, _segment) = segment_mask.roi_zero_padded(bbox_cast, padded_cast);
}
#[test]
fn test_roi_zero_padded_3d() {
let arr = ndarray::Array3::<u8>::ones((10, 10, 3));
let original = bounding_box::Aabb2::from_xywh(1.0, 1.0, 3.0, 3.0);
let clamp = bounding_box::Aabb2::from_xywh(0.0, 0.0, 10.0, 10.0);
let padded = original.padding(2.0).clamp(&clamp).unwrap();
let padded_cast = bbox_cast_f64_to_usize(&padded);
let original_cast = bbox_cast_f64_to_usize(&original);
let (padded_result, padded_segment) = arr.roi_zero_padded(original_cast, padded_cast);
assert_eq!(padded_result, bounding_box::Aabb2::from_xywh(0, 0, 6, 6));
assert_eq!(padded_segment.shape(), &[6, 6, 3]);
}

1
rfcs

Submodule rfcs deleted from c973203daf

View File

@@ -0,0 +1,14 @@
[package]
name = "sqlite3-ndarray-math"
version.workspace = true
edition.workspace = true
[lib]
crate-type = ["cdylib", "staticlib"]
[dependencies]
ndarray = "0.16.1"
ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" }
# ndarray-math = { path = "/Users/fs0c131y/Projects/ndarray-math", version = "0.1.0" }
ndarray-safetensors = { version = "0.1.0", path = "../ndarray-safetensors" }
sqlite-loadable = "0.0.5"

View File

@@ -0,0 +1,61 @@
use sqlite_loadable::prelude::*;
use sqlite_loadable::{Error, ErrorKind};
use sqlite_loadable::{Result, api, define_scalar_function};
fn cosine_similarity(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> {
#[inline(always)]
fn custom_error(err: impl core::error::Error) -> sqlite_loadable::Error {
sqlite_loadable::Error::new(sqlite_loadable::ErrorKind::Message(err.to_string()))
}
if values.len() != 2 {
return Err(Error::new(ErrorKind::Message(
"cosine_similarity requires exactly 2 arguments".to_string(),
)));
}
let array_1 = api::value_blob(values.get(0).expect("1st argument"));
let array_2 = api::value_blob(values.get(1).expect("2nd argument"));
let array_1_st =
ndarray_safetensors::SafeArraysView::from_bytes(array_1).map_err(custom_error)?;
let array_2_st =
ndarray_safetensors::SafeArraysView::from_bytes(array_2).map_err(custom_error)?;
let array_view_1 = array_1_st
.tensor_by_index::<f32, ndarray::Ix1>(0)
.map_err(custom_error)?;
let array_view_2 = array_2_st
.tensor_by_index::<f32, ndarray::Ix1>(0)
.map_err(custom_error)?;
use ndarray_math::*;
let similarity = array_view_1
.cosine_similarity(array_view_2)
.map_err(custom_error)?;
api::result_double(context, similarity as f64);
Ok(())
}
pub fn _sqlite3_extension_init(db: *mut sqlite3) -> Result<()> {
define_scalar_function(
db,
"cosine_similarity",
2,
cosine_similarity,
FunctionFlags::DETERMINISTIC,
)?;
Ok(())
}
/// # Safety
///
/// Should only be called by underlying SQLite C APIs,
/// like sqlite3_auto_extension and sqlite3_cancel_auto_extension.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn sqlite3_extension_init(
db: *mut sqlite3,
pz_err_msg: *mut *mut c_char,
p_api: *mut sqlite3_api_routines,
) -> c_uint {
register_entrypoint(db, pz_err_msg, p_api, _sqlite3_extension_init)
}

213
src/bin/detector-cli/cli.rs Normal file
View File

@@ -0,0 +1,213 @@
use detector::ort_ep;
use std::path::PathBuf;
#[derive(Debug, clap::Parser)]
pub struct Cli {
#[clap(subcommand)]
pub cmd: SubCommand,
}
#[derive(Debug, clap::Subcommand)]
pub enum SubCommand {
#[clap(name = "detect")]
Detect(Detect),
#[clap(name = "detect-multi")]
DetectMulti(DetectMulti),
#[clap(name = "query")]
Query(Query),
#[clap(name = "similar")]
Similar(Similar),
#[clap(name = "stats")]
Stats(Stats),
#[clap(name = "compare")]
Compare(Compare),
#[clap(name = "cluster")]
Cluster(Cluster),
#[clap(name = "gui")]
Gui,
#[clap(name = "completions")]
Completions { shell: clap_complete::Shell },
}
#[derive(Debug, clap::ValueEnum, Clone, Copy, PartialEq)]
pub enum Models {
RetinaFace,
Yolo,
}
#[derive(Debug, Clone)]
pub enum Executor {
Mnn(mnn::ForwardType),
Ort(Vec<ort_ep::ExecutionProvider>),
}
#[derive(Debug, clap::Args)]
pub struct Detect {
#[clap(short, long)]
pub model: Option<PathBuf>,
#[clap(short = 'M', long, default_value = "retina-face")]
pub model_type: Models,
#[clap(short, long)]
pub output: Option<PathBuf>,
#[clap(
short = 'p',
long,
default_value = "cpu",
group = "execution_provider",
required_unless_present = "mnn_forward_type"
)]
pub ort_execution_provider: Vec<ort_ep::ExecutionProvider>,
#[clap(
short = 'f',
long,
group = "execution_provider",
required_unless_present = "ort_execution_provider"
)]
pub mnn_forward_type: Option<mnn::ForwardType>,
#[clap(short, long, default_value_t = 0.8)]
pub threshold: f32,
#[clap(short, long, default_value_t = 0.3)]
pub nms_threshold: f32,
#[clap(short, long, default_value_t = 8)]
pub batch_size: usize,
#[clap(short = 'd', long)]
pub database: Option<PathBuf>,
#[clap(long, default_value = "facenet")]
pub model_name: String,
#[clap(long)]
pub save_to_db: bool,
pub image: PathBuf,
}
#[derive(Debug, clap::Args)]
pub struct DetectMulti {
#[clap(short, long)]
pub model: Option<PathBuf>,
#[clap(short = 'M', long, default_value = "retina-face")]
pub model_type: Models,
#[clap(short, long)]
pub output_dir: Option<PathBuf>,
#[clap(
short = 'p',
long,
default_value = "cpu",
group = "execution_provider",
required_unless_present = "mnn_forward_type"
)]
pub ort_execution_provider: Vec<ort_ep::ExecutionProvider>,
#[clap(
short = 'f',
long,
group = "execution_provider",
required_unless_present = "ort_execution_provider"
)]
pub mnn_forward_type: Option<mnn::ForwardType>,
#[clap(short, long, default_value_t = 0.8)]
pub threshold: f32,
#[clap(short, long, default_value_t = 0.3)]
pub nms_threshold: f32,
#[clap(short, long, default_value_t = 8)]
pub batch_size: usize,
#[clap(short = 'd', long, default_value = "face_detections.db")]
pub database: PathBuf,
#[clap(long, default_value = "facenet")]
pub model_name: String,
#[clap(
long,
help = "Image extensions to process (e.g., jpg,png,jpeg)",
default_value = "jpg,jpeg,png,bmp,tiff,webp"
)]
pub extensions: String,
#[clap(help = "Directory containing images to process")]
pub input_dir: PathBuf,
}
#[derive(Debug, clap::Args)]
pub struct Query {
#[clap(short = 'd', long, default_value = "face_detections.db")]
pub database: PathBuf,
#[clap(short, long)]
pub image_id: Option<i64>,
#[clap(short, long)]
pub face_id: Option<i64>,
#[clap(long)]
pub show_embeddings: bool,
#[clap(long)]
pub show_landmarks: bool,
}
#[derive(Debug, clap::Args)]
pub struct Similar {
#[clap(short = 'd', long, default_value = "face_detections.db")]
pub database: PathBuf,
#[clap(short, long)]
pub face_id: i64,
#[clap(short, long, default_value_t = 0.7)]
pub threshold: f32,
#[clap(short, long, default_value_t = 10)]
pub limit: usize,
}
#[derive(Debug, clap::Args)]
pub struct Stats {
#[clap(short = 'd', long, default_value = "face_detections.db")]
pub database: PathBuf,
}
#[derive(Debug, clap::Args)]
pub struct Compare {
#[clap(short, long)]
pub model: Option<PathBuf>,
#[clap(short = 'M', long, default_value = "retina-face")]
pub model_type: Models,
#[clap(
short = 'p',
long,
default_value = "cpu",
group = "execution_provider",
required_unless_present = "mnn_forward_type"
)]
pub ort_execution_provider: Vec<ort_ep::ExecutionProvider>,
#[clap(
short = 'f',
long,
group = "execution_provider",
required_unless_present = "ort_execution_provider"
)]
pub mnn_forward_type: Option<mnn::ForwardType>,
#[clap(short, long, default_value_t = 0.8)]
pub threshold: f32,
#[clap(short, long, default_value_t = 0.3)]
pub nms_threshold: f32,
#[clap(short, long, default_value_t = 8)]
pub batch_size: usize,
#[clap(long, default_value = "facenet")]
pub model_name: String,
#[clap(help = "First image to compare")]
pub image1: PathBuf,
#[clap(help = "Second image to compare")]
pub image2: PathBuf,
}
#[derive(Debug, clap::Args)]
pub struct Cluster {
#[clap(short = 'd', long, default_value = "face_detections.db")]
pub database: PathBuf,
#[clap(short, long, default_value_t = 5)]
pub clusters: usize,
#[clap(short, long, default_value_t = 100)]
pub max_iterations: u64,
#[clap(short, long, default_value_t = 1e-4)]
pub tolerance: f64,
#[clap(long, default_value = "facenet")]
pub model_name: String,
#[clap(short, long)]
pub output: Option<PathBuf>,
}
impl Cli {
pub fn completions(shell: clap_complete::Shell) {
let mut command = <Cli as clap::CommandFactory>::command();
clap_complete::generate(shell, &mut command, "detector", &mut std::io::stdout());
}
}

1063
src/bin/detector-cli/main.rs Normal file

File diff suppressed because it is too large Load Diff

19
src/bin/gui.rs Normal file
View File

@@ -0,0 +1,19 @@
use detector::errors::*;
fn main() -> Result<()> {
// Initialize logging
tracing_subscriber::fmt()
.with_env_filter("warn,ort=warn")
.with_file(true)
.with_line_number(true)
// .with_thread_names(true)
.with_target(true)
.init();
// Run the GUI
if let Err(e) = detector::gui::run() {
eprintln!("GUI error: {}", e);
std::process::exit(1);
}
Ok(())
}

View File

@@ -1,77 +0,0 @@
use std::path::PathBuf;
use mnn::ForwardType;
#[derive(Debug, clap::Parser)]
pub struct Cli {
#[clap(subcommand)]
pub cmd: SubCommand,
}
#[derive(Debug, clap::Subcommand)]
pub enum SubCommand {
#[clap(name = "detect")]
Detect(Detect),
#[clap(name = "list")]
List(List),
#[clap(name = "completions")]
Completions { shell: clap_complete::Shell },
}
#[derive(Debug, clap::ValueEnum, Clone, Copy)]
pub enum Models {
RetinaFace,
Yolo,
}
#[derive(Debug, Clone)]
pub enum Executor {
Mnn(mnn::ForwardType),
Ort(Vec<detector::ort_ep::ExecutionProvider>),
}
#[derive(Debug, clap::Args)]
pub struct Detect {
#[clap(short, long)]
pub model: Option<PathBuf>,
#[clap(short = 'M', long, default_value = "retina-face")]
pub model_type: Models,
#[clap(short, long)]
pub output: Option<PathBuf>,
#[clap(
short = 'p',
long,
default_value = "cpu",
group = "execution_provider",
required_unless_present = "mnn_forward_type"
)]
pub ort_execution_provider: Vec<detector::ort_ep::ExecutionProvider>,
#[clap(
short = 'f',
long,
group = "execution_provider",
required_unless_present = "ort_execution_provider"
)]
pub mnn_forward_type: Option<mnn::ForwardType>,
#[clap(short, long, default_value_t = 0.8)]
pub threshold: f32,
#[clap(short, long, default_value_t = 0.3)]
pub nms_threshold: f32,
#[clap(short, long, default_value_t = 8)]
pub batch_size: usize,
pub image: PathBuf,
}
#[derive(Debug, clap::Args)]
pub struct List {}
impl Cli {
pub fn completions(shell: clap_complete::Shell) {
let mut command = <Cli as clap::CommandFactory>::command();
clap_complete::generate(
shell,
&mut command,
env!("CARGO_BIN_NAME"),
&mut std::io::stdout(),
);
}
}

734
src/database.rs Normal file
View File

@@ -0,0 +1,734 @@
use crate::errors::{Error, Result};
use crate::facedet::{FaceDetectionOutput, FaceLandmarks};
use bounding_box::Aabb2;
use error_stack::ResultExt;
use ndarray_math::CosineSimilarity;
use rusqlite::{Connection, OptionalExtension, params};
use std::path::Path;
/// Database connection and operations for face detection results
pub struct FaceDatabase {
conn: Connection,
}
/// Represents a stored image record
#[derive(Debug, Clone)]
pub struct ImageRecord {
pub id: i64,
pub file_path: String,
pub width: u32,
pub height: u32,
pub created_at: String,
}
/// Represents a stored face detection record
#[derive(Debug, Clone)]
pub struct FaceRecord {
pub id: i64,
pub image_id: i64,
pub bbox_x1: f32,
pub bbox_y1: f32,
pub bbox_x2: f32,
pub bbox_y2: f32,
pub confidence: f32,
pub created_at: String,
}
/// Represents stored face landmarks
#[derive(Debug, Clone)]
pub struct LandmarkRecord {
pub id: i64,
pub face_id: i64,
pub left_eye_x: f32,
pub left_eye_y: f32,
pub right_eye_x: f32,
pub right_eye_y: f32,
pub nose_x: f32,
pub nose_y: f32,
pub left_mouth_x: f32,
pub left_mouth_y: f32,
pub right_mouth_x: f32,
pub right_mouth_y: f32,
}
/// Represents a stored face embedding
#[derive(Debug, Clone)]
pub struct EmbeddingRecord {
pub id: i64,
pub face_id: i64,
pub embedding: ndarray::Array1<f32>,
pub model_name: String,
pub created_at: String,
}
impl FaceDatabase {
/// Create a new database connection and initialize tables
pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Self> {
let conn = Connection::open(db_path).change_context(Error)?;
// Temporarily disable extension loading for clustering demo
// unsafe {
// let _guard = rusqlite::LoadExtensionGuard::new(&conn).change_context(Error)?;
// conn.load_extension(
// "/Users/fs0c131y/.cache/cargo/target/release/libsqlite3_safetensor_cosine.dylib",
// None::<&str>,
// )
// .change_context(Error)?;
// }
let db = Self { conn };
db.create_tables()?;
Ok(db)
}
/// Create an in-memory database for testing
pub fn in_memory() -> Result<Self> {
let conn = Connection::open_in_memory().change_context(Error)?;
let db = Self { conn };
db.create_tables()?;
Ok(db)
}
/// Create all necessary database tables
fn create_tables(&self) -> Result<()> {
// Images table
self.conn
.execute(
r#"
CREATE TABLE IF NOT EXISTS images (
id INTEGER PRIMARY KEY AUTOINCREMENT,
file_path TEXT NOT NULL UNIQUE,
width INTEGER NOT NULL,
height INTEGER NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
"#,
[],
)
.change_context(Error)?;
// Faces table
self.conn
.execute(
r#"
CREATE TABLE IF NOT EXISTS faces (
id INTEGER PRIMARY KEY AUTOINCREMENT,
image_id INTEGER NOT NULL,
bbox_x1 REAL NOT NULL,
bbox_y1 REAL NOT NULL,
bbox_x2 REAL NOT NULL,
bbox_y2 REAL NOT NULL,
confidence REAL NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE
)
"#,
[],
)
.change_context(Error)?;
// Landmarks table
self.conn
.execute(
r#"
CREATE TABLE IF NOT EXISTS landmarks (
id INTEGER PRIMARY KEY AUTOINCREMENT,
face_id INTEGER NOT NULL,
left_eye_x REAL NOT NULL,
left_eye_y REAL NOT NULL,
right_eye_x REAL NOT NULL,
right_eye_y REAL NOT NULL,
nose_x REAL NOT NULL,
nose_y REAL NOT NULL,
left_mouth_x REAL NOT NULL,
left_mouth_y REAL NOT NULL,
right_mouth_x REAL NOT NULL,
right_mouth_y REAL NOT NULL,
FOREIGN KEY (face_id) REFERENCES faces (id) ON DELETE CASCADE
)
"#,
[],
)
.change_context(Error)?;
// Embeddings table
self.conn
.execute(
r#"
CREATE TABLE IF NOT EXISTS embeddings (
id INTEGER PRIMARY KEY AUTOINCREMENT,
face_id INTEGER NOT NULL,
embedding BLOB NOT NULL,
model_name TEXT NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (face_id) REFERENCES faces (id) ON DELETE CASCADE
)
"#,
[],
)
.change_context(Error)?;
// Create indexes for better performance
self.conn
.execute(
"CREATE INDEX IF NOT EXISTS idx_faces_image_id ON faces (image_id)",
[],
)
.change_context(Error)?;
self.conn
.execute(
"CREATE INDEX IF NOT EXISTS idx_landmarks_face_id ON landmarks (face_id)",
[],
)
.change_context(Error)?;
self.conn
.execute(
"CREATE INDEX IF NOT EXISTS idx_embeddings_face_id ON embeddings (face_id)",
[],
)
.change_context(Error)?;
Ok(())
}
/// Store image metadata and return the image ID
pub fn store_image(&self, file_path: &str, width: u32, height: u32) -> Result<i64> {
let mut stmt = self
.conn
.prepare("INSERT OR REPLACE INTO images (file_path, width, height) VALUES (?1, ?2, ?3)")
.change_context(Error)?;
Ok(stmt
.insert(params![file_path, width, height])
.change_context(Error)?)
}
/// Store face detection results
pub fn store_face_detections(
&self,
image_id: i64,
detection_output: &FaceDetectionOutput,
) -> Result<Vec<i64>> {
let mut face_ids = Vec::new();
for (i, bbox) in detection_output.bbox.iter().enumerate() {
let confidence = detection_output.confidence.get(i).copied().unwrap_or(0.0);
let face_id = self.store_face(image_id, bbox, confidence)?;
face_ids.push(face_id);
// Store landmarks if available
if let Some(landmarks) = detection_output.landmark.get(i) {
self.store_landmarks(face_id, landmarks)?;
}
}
Ok(face_ids)
}
/// Store a single face detection
pub fn store_face(&self, image_id: i64, bbox: &Aabb2<usize>, confidence: f32) -> Result<i64> {
let mut stmt = self
.conn
.prepare(
r#"
INSERT INTO faces (image_id, bbox_x1, bbox_y1, bbox_x2, bbox_y2, confidence)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)
"#,
)
.change_context(Error)?;
Ok(stmt
.insert(params![
image_id,
bbox.x1() as f32,
bbox.y1() as f32,
bbox.x2() as f32,
bbox.y2() as f32,
confidence
])
.change_context(Error)?)
}
/// Store face landmarks
pub fn store_landmarks(&self, face_id: i64, landmarks: &FaceLandmarks) -> Result<i64> {
let mut stmt = self
.conn
.prepare(
r#"
INSERT INTO landmarks
(face_id, left_eye_x, left_eye_y, right_eye_x, right_eye_y,
nose_x, nose_y, left_mouth_x, left_mouth_y, right_mouth_x, right_mouth_y)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)
"#,
)
.change_context(Error)?;
Ok(stmt
.insert(params![
face_id,
landmarks.left_eye.x,
landmarks.left_eye.y,
landmarks.right_eye.x,
landmarks.right_eye.y,
landmarks.nose.x,
landmarks.nose.y,
landmarks.left_mouth.x,
landmarks.left_mouth.y,
landmarks.right_mouth.x,
landmarks.right_mouth.y,
])
.change_context(Error)?)
}
/// Store face embeddings
pub fn store_embeddings(
&self,
face_ids: &[i64],
embeddings: &[ndarray::Array2<f32>],
model_name: &str,
) -> Result<Vec<i64>> {
let mut embedding_ids = Vec::new();
for (face_idx, embedding_batch) in embeddings.iter().enumerate() {
for (batch_idx, embedding_row) in embedding_batch.rows().into_iter().enumerate() {
let global_idx = face_idx * embedding_batch.nrows() + batch_idx;
if global_idx >= face_ids.len() {
break;
}
let face_id = face_ids[global_idx];
let embedding_id =
self.store_single_embedding(face_id, embedding_row, model_name)?;
embedding_ids.push(embedding_id);
}
}
Ok(embedding_ids)
}
/// Store a single embedding
pub fn store_single_embedding(
&self,
face_id: i64,
embedding: ndarray::ArrayView1<f32>,
model_name: &str,
) -> Result<i64> {
let safe_arrays =
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding)])
.change_context(Error)?;
let embedding_bytes = safe_arrays.serialize().change_context(Error)?;
let mut stmt = self
.conn
.prepare("INSERT INTO embeddings (face_id, embedding, model_name) VALUES (?1, ?2, ?3)")
.change_context(Error)?;
stmt.execute(params![face_id, embedding_bytes, model_name])
.change_context(Error)?;
Ok(self.conn.last_insert_rowid())
}
/// Get image by ID
pub fn get_image(&self, image_id: i64) -> Result<Option<ImageRecord>> {
let mut stmt = self
.conn
.prepare("SELECT id, file_path, width, height, created_at FROM images WHERE id = ?1")
.change_context(Error)?;
let result = stmt
.query_row(params![image_id], |row| {
Ok(ImageRecord {
id: row.get(0)?,
file_path: row.get(1)?,
width: row.get(2)?,
height: row.get(3)?,
created_at: row.get(4)?,
})
})
.optional()
.change_context(Error)?;
Ok(result)
}
/// Get all faces for an image
pub fn get_faces_for_image(&self, image_id: i64) -> Result<Vec<FaceRecord>> {
let mut stmt = self
.conn
.prepare(
r#"
SELECT id, image_id, bbox_x1, bbox_y1, bbox_x2, bbox_y2, confidence, created_at
FROM faces WHERE image_id = ?1
"#,
)
.change_context(Error)?;
let face_iter = stmt
.query_map(params![image_id], |row| {
Ok(FaceRecord {
id: row.get(0)?,
image_id: row.get(1)?,
bbox_x1: row.get(2)?,
bbox_y1: row.get(3)?,
bbox_x2: row.get(4)?,
bbox_y2: row.get(5)?,
confidence: row.get(6)?,
created_at: row.get(7)?,
})
})
.change_context(Error)?;
let mut faces = Vec::new();
for face in face_iter {
faces.push(face.change_context(Error)?);
}
Ok(faces)
}
/// Get landmarks for a face
pub fn get_landmarks(&self, face_id: i64) -> Result<Option<LandmarkRecord>> {
let mut stmt = self
.conn
.prepare(
r#"
SELECT id, face_id, left_eye_x, left_eye_y, right_eye_x, right_eye_y,
nose_x, nose_y, left_mouth_x, left_mouth_y, right_mouth_x, right_mouth_y
FROM landmarks WHERE face_id = ?1
"#,
)
.change_context(Error)?;
let result = stmt
.query_row(params![face_id], |row| {
Ok(LandmarkRecord {
id: row.get(0)?,
face_id: row.get(1)?,
left_eye_x: row.get(2)?,
left_eye_y: row.get(3)?,
right_eye_x: row.get(4)?,
right_eye_y: row.get(5)?,
nose_x: row.get(6)?,
nose_y: row.get(7)?,
left_mouth_x: row.get(8)?,
left_mouth_y: row.get(9)?,
right_mouth_x: row.get(10)?,
right_mouth_y: row.get(11)?,
})
})
.optional()
.change_context(Error)?;
Ok(result)
}
/// Get embeddings for a face
pub fn get_embeddings(&self, face_id: i64) -> Result<Vec<EmbeddingRecord>> {
let mut stmt = self
.conn
.prepare(
"SELECT id, face_id, embedding, model_name, created_at FROM embeddings WHERE face_id = ?1",
)
.change_context(Error)?;
let embedding_iter = stmt
.query_map(params![face_id], |row| {
let embedding_bytes: Vec<u8> = row.get(2)?;
let embedding: ndarray::Array1<f32> = {
let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes)
.change_context(Error)
// .change_context(Error)?
.unwrap();
sf.tensor::<f32, ndarray::Ix1>("embedding")
// .change_context(Error)?
.unwrap()
.to_owned()
};
Ok(EmbeddingRecord {
id: row.get(0)?,
face_id: row.get(1)?,
embedding,
model_name: row.get(3)?,
created_at: row.get(4)?,
})
})
.change_context(Error)?;
let mut embeddings = Vec::new();
for embedding in embedding_iter {
embeddings.push(embedding.change_context(Error)?);
}
Ok(embeddings)
}
pub fn get_image_for_face(&self, face_id: i64) -> Result<Option<ImageRecord>> {
let mut stmt = self
.conn
.prepare(
r#"
SELECT images.id, images.file_path, images.width, images.height, images.created_at
FROM images
JOIN faces ON faces.image_id = images.id
WHERE faces.id = ?1
"#,
)
.change_context(Error)?;
let result = stmt
.query_row(params![face_id], |row| {
Ok(ImageRecord {
id: row.get(0)?,
file_path: row.get(1)?,
width: row.get(2)?,
height: row.get(3)?,
created_at: row.get(4)?,
})
})
.optional()
.change_context(Error)?;
Ok(result)
}
/// Get database statistics
pub fn get_stats(&self) -> Result<(usize, usize, usize, usize)> {
let images: usize = self
.conn
.query_row("SELECT COUNT(*) FROM images", [], |row| row.get(0))
.change_context(Error)?;
let faces: usize = self
.conn
.query_row("SELECT COUNT(*) FROM faces", [], |row| row.get(0))
.change_context(Error)?;
let landmarks: usize = self
.conn
.query_row("SELECT COUNT(*) FROM landmarks", [], |row| row.get(0))
.change_context(Error)?;
let embeddings: usize = self
.conn
.query_row("SELECT COUNT(*) FROM embeddings", [], |row| row.get(0))
.change_context(Error)?;
Ok((images, faces, landmarks, embeddings))
}
/// Find similar faces based on cosine similarity of embeddings
/// Return ids and similarity scores of similar faces
pub fn find_similar_faces(
&self,
embedding: &ndarray::Array1<f32>,
threshold: f32,
limit: usize,
) -> Result<Vec<(i64, f32)>> {
// Serialize the query embedding to bytes
let embedding_bytes =
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding.view())])
.change_context(Error)?
.serialize()
.change_context(Error)?;
let mut stmt = self
.conn
.prepare(
r#" SELECT face_id, cosine_similarity(?1, embedding) as similarity
FROM embeddings
WHERE cosine_similarity(?1, embedding) >= ?2
ORDER BY similarity DESC
LIMIT ?3"#,
)
.change_context(Error)?;
let result = stmt
.query_map(params![embedding_bytes, threshold, limit], |row| {
Ok((row.get::<_, i64>(0)?, row.get::<_, f32>(1)?))
})
.change_context(Error)?
.map(|r| r.change_context(Error))
.collect::<Result<Vec<_>>>()?;
// let mut results = Vec::new();
// for result in result_iter {
// results.push(result.change_context(Error)?);
// }
Ok(result)
}
pub fn query_similarity(&self, embedding: &ndarray::Array1<f32>) {
let embedding_bytes =
ndarray_safetensors::SafeArrays::from_ndarrays([("embedding", embedding.view())])
.change_context(Error)
.unwrap()
.serialize()
.change_context(Error)
.unwrap();
let mut stmt = self
.conn
.prepare(
r#"
SELECT face_id,
cosine_similarity(?1, embedding)
FROM embeddings
"#,
)
.change_context(Error)
.unwrap();
let result_iter = stmt
.query_map(params![embedding_bytes], |row| {
Ok((row.get::<_, i64>(0)?, row.get::<_, f32>(1)?))
})
.change_context(Error)
.unwrap();
for result in result_iter {
println!("{:?}", result);
}
}
/// Get all embeddings for a specific model
pub fn get_all_embeddings(&self, model_name: Option<&str>) -> Result<Vec<EmbeddingRecord>> {
let mut embeddings = Vec::new();
if let Some(model) = model_name {
let mut stmt = self.conn.prepare(
"SELECT id, face_id, embedding, model_name, created_at FROM embeddings WHERE model_name = ?1"
).change_context(Error)?;
let embedding_iter = stmt
.query_map(params![model], |row| {
let embedding_bytes: Vec<u8> = row.get(2)?;
let embedding: ndarray::Array1<f32> = {
let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes)
.change_context(Error)
.unwrap();
sf.tensor::<f32, ndarray::Ix1>("embedding")
.unwrap()
.to_owned()
};
Ok(EmbeddingRecord {
id: row.get(0)?,
face_id: row.get(1)?,
embedding,
model_name: row.get(3)?,
created_at: row.get(4)?,
})
})
.change_context(Error)?;
for embedding in embedding_iter {
embeddings.push(embedding.change_context(Error)?);
}
} else {
let mut stmt = self
.conn
.prepare("SELECT id, face_id, embedding, model_name, created_at FROM embeddings")
.change_context(Error)?;
let embedding_iter = stmt
.query_map([], |row| {
let embedding_bytes: Vec<u8> = row.get(2)?;
let embedding: ndarray::Array1<f32> = {
let sf = ndarray_safetensors::SafeArraysView::from_bytes(&embedding_bytes)
.change_context(Error)
.unwrap();
sf.tensor::<f32, ndarray::Ix1>("embedding")
.unwrap()
.to_owned()
};
Ok(EmbeddingRecord {
id: row.get(0)?,
face_id: row.get(1)?,
embedding,
model_name: row.get(3)?,
created_at: row.get(4)?,
})
})
.change_context(Error)?;
for embedding in embedding_iter {
embeddings.push(embedding.change_context(Error)?);
}
}
Ok(embeddings)
}
}
fn add_sqlite_cosine_similarity(db: &Connection) -> Result<()> {
use rusqlite::functions::*;
db.create_scalar_function(
"cosine_similarity",
2,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
move |ctx| {
if ctx.len() != 2 {
return Err(rusqlite::Error::UserFunctionError(
"cosine_similarity requires exactly 2 arguments".into(),
));
}
let array_1 = ctx.get_raw(0).as_blob()?;
let array_2 = ctx.get_raw(1).as_blob()?;
let array_1_st = ndarray_safetensors::SafeArraysView::from_bytes(array_1)
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
let array_2_st = ndarray_safetensors::SafeArraysView::from_bytes(array_2)
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
let array_view_1 = array_1_st
.tensor_by_index::<f32, ndarray::Ix1>(0)
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
let array_view_2 = array_2_st
.tensor_by_index::<f32, ndarray::Ix1>(0)
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
let similarity = array_view_1
.cosine_similarity(array_view_2)
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
Ok(similarity)
},
)
.change_context(Error)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_database_creation() -> Result<()> {
let db = FaceDatabase::in_memory()?;
let (images, faces, landmarks, embeddings) = db.get_stats()?;
assert_eq!(images, 0);
assert_eq!(faces, 0);
assert_eq!(landmarks, 0);
assert_eq!(embeddings, 0);
Ok(())
}
#[test]
fn test_store_and_retrieve_image() -> Result<()> {
let db = FaceDatabase::in_memory()?;
let image_id = db.store_image("/path/to/image.jpg", 800, 600)?;
let image = db.get_image(image_id)?.unwrap();
assert_eq!(image.file_path, "/path/to/image.jpg");
assert_eq!(image.width, 800);
assert_eq!(image.height, 600);
Ok(())
}
}

View File

@@ -170,12 +170,13 @@ impl FaceDetectionModelOutput {
let boxes = self.bbox.slice(s![0, .., ..]); let boxes = self.bbox.slice(s![0, .., ..]);
let landmarks_raw = self.landmark.slice(s![0, .., ..]); let landmarks_raw = self.landmark.slice(s![0, .., ..]);
let mut decoded_boxes = Vec::new(); // let mut decoded_boxes = Vec::new();
let mut decoded_landmarks = Vec::new(); // let mut decoded_landmarks = Vec::new();
let mut confidences = Vec::new(); // let mut confidences = Vec::new();
for i in 0..priors.shape()[0] { let (decoded_boxes, decoded_landmarks, confidences) = (0..priors.shape()[0])
if scores[i] > config.threshold { .filter(|&i| scores[i] > config.threshold)
.map(|i| {
let prior = priors.row(i); let prior = priors.row(i);
let loc = boxes.row(i); let loc = boxes.row(i);
let landm = landmarks_raw.row(i); let landm = landmarks_raw.row(i);
@@ -200,16 +201,21 @@ impl FaceDetectionModelOutput {
let mut bbox = let mut bbox =
Aabb2::from_min_max_vertices(Point2::new(xmin, ymin), Point2::new(xmax, ymax)); Aabb2::from_min_max_vertices(Point2::new(xmin, ymin), Point2::new(xmax, ymax));
if config.clamp { if config.clamp {
bbox.component_clamp(0.0, 1.0); bbox = bbox.component_clamp(0.0, 1.0);
} }
decoded_boxes.push(bbox);
// Decode landmarks // Decode landmarks
let mut points = [Point2::new(0.0, 0.0); 5]; let points: [Point2<f32>; 5] = (0..5)
for j in 0..5 { .map(|j| {
points[j].x = prior_cx + landm[j * 2] * var[0] * prior_w; Point2::new(
points[j].y = prior_cy + landm[j * 2 + 1] * var[0] * prior_h; prior_cx + landm[j * 2] * var[0] * prior_w,
} prior_cy + landm[j * 2 + 1] * var[0] * prior_h,
)
})
.collect::<Vec<_>>()
.try_into()
.unwrap();
let landmarks = FaceLandmarks { let landmarks = FaceLandmarks {
left_eye: points[0], left_eye: points[0],
right_eye: points[1], right_eye: points[1],
@@ -217,11 +223,18 @@ impl FaceDetectionModelOutput {
left_mouth: points[3], left_mouth: points[3],
right_mouth: points[4], right_mouth: points[4],
}; };
decoded_landmarks.push(landmarks);
confidences.push(scores[i]);
}
}
(bbox, landmarks, scores[i])
})
.fold(
(Vec::new(), Vec::new(), Vec::new()),
|(mut boxes, mut landmarks, mut confs), (bbox, landmark, conf)| {
boxes.push(bbox);
landmarks.push(landmark);
confs.push(conf);
(boxes, landmarks, confs)
},
);
Ok(FaceDetectionProcessedOutput { Ok(FaceDetectionProcessedOutput {
bbox: decoded_boxes, bbox: decoded_boxes,
confidence: confidences, confidence: confidences,
@@ -310,7 +323,7 @@ pub trait FaceDetector {
fn detect_faces( fn detect_faces(
&mut self, &mut self,
image: ndarray::ArrayView3<u8>, image: ndarray::ArrayView3<u8>,
config: FaceDetectionConfig, config: &FaceDetectionConfig,
) -> Result<FaceDetectionOutput> { ) -> Result<FaceDetectionOutput> {
let (height, width, _channels) = image.dim(); let (height, width, _channels) = image.dim();
let output = self let output = self

View File

@@ -11,6 +11,23 @@ pub use facenet::ort::EmbeddingGenerator as OrtEmbeddingGenerator;
use crate::errors::*; use crate::errors::*;
use ndarray::{Array2, ArrayView4}; use ndarray::{Array2, ArrayView4};
pub mod preprocessing {
use ndarray::*;
pub fn preprocess(faces: ArrayView4<u8>) -> Array4<f32> {
let mut owned = faces.as_standard_layout().mapv(|v| v as f32).to_owned();
owned.axis_iter_mut(Axis(0)).for_each(|mut image| {
let mean = image.mean().unwrap_or(0.0);
let std = image.std(0.0);
if std > 0.0 {
image.mapv_inplace(|x| (x - mean) / std);
} else {
image.mapv_inplace(|x| (x - 127.5) / 128.0)
}
});
owned
}
}
/// Common trait for face embedding backends - maintained for backward compatibility /// Common trait for face embedding backends - maintained for backward compatibility
pub trait FaceEmbedder { pub trait FaceEmbedder {
/// Generate embeddings for a batch of face images /// Generate embeddings for a batch of face images

View File

@@ -4,6 +4,7 @@ pub mod ort;
use crate::errors::*; use crate::errors::*;
use error_stack::ResultExt; use error_stack::ResultExt;
use ndarray::{Array1, Array2, ArrayView3, ArrayView4}; use ndarray::{Array1, Array2, ArrayView3, ArrayView4};
use ndarray_math::{CosineSimilarity, EuclideanDistance};
/// Configuration for face embedding processing /// Configuration for face embedding processing
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
@@ -32,9 +33,9 @@ impl FaceEmbeddingConfig {
impl Default for FaceEmbeddingConfig { impl Default for FaceEmbeddingConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
input_width: 160, input_width: 320,
input_height: 160, input_height: 320,
normalize: true, normalize: false,
} }
} }
} }
@@ -63,15 +64,14 @@ impl FaceEmbedding {
/// Calculate cosine similarity with another embedding /// Calculate cosine similarity with another embedding
pub fn cosine_similarity(&self, other: &FaceEmbedding) -> f32 { pub fn cosine_similarity(&self, other: &FaceEmbedding) -> f32 {
let dot_product = self.vector.dot(&other.vector); self.vector.cosine_similarity(&other.vector).unwrap_or(0.0)
let norm_self = self.vector.mapv(|x| x * x).sum().sqrt();
let norm_other = other.vector.mapv(|x| x * x).sum().sqrt();
dot_product / (norm_self * norm_other)
} }
/// Calculate Euclidean distance with another embedding /// Calculate Euclidean distance with another embedding
pub fn euclidean_distance(&self, other: &FaceEmbedding) -> f32 { pub fn euclidean_distance(&self, other: &FaceEmbedding) -> f32 {
(&self.vector - &other.vector).mapv(|x| x * x).sum().sqrt() self.vector
.euclidean_distance(other.vector.view())
.unwrap_or(f32::INFINITY)
} }
/// Normalize the embedding vector to unit length /// Normalize the embedding vector to unit length

View File

@@ -64,10 +64,7 @@ impl EmbeddingGenerator {
} }
pub fn run_models(&self, face: ArrayView4<u8>) -> Result<Array2<f32>> { pub fn run_models(&self, face: ArrayView4<u8>) -> Result<Array2<f32>> {
let tensor = face let tensor = crate::faceembed::preprocessing::preprocess(face);
// .permuted_axes((0, 3, 1, 2))
.as_standard_layout()
.mapv(|x| x as f32);
let shape: [usize; 4] = tensor.dim().into(); let shape: [usize; 4] = tensor.dim().into();
let shape = shape.map(|f| f as i32); let shape = shape.map(|f| f as i32);
let output = self let output = self

View File

@@ -135,10 +135,12 @@ impl EmbeddingGenerator {
pub fn run_models(&mut self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> { pub fn run_models(&mut self, faces: ArrayView4<u8>) -> crate::errors::Result<Array2<f32>> {
// Convert input from u8 to f32 and normalize to [0, 1] range // Convert input from u8 to f32 and normalize to [0, 1] range
let input_tensor = faces let input_tensor = crate::faceembed::preprocessing::preprocess(faces);
.mapv(|x| x as f32 / 255.0)
.as_standard_layout() // face_array = np.asarray(face_resized, 'float32')
.into_owned(); // mean, std = face_array.mean(), face_array.std()
// face_normalized = (face_array - mean) / std
// let input_tensor = faces.mean()
tracing::trace!("Input tensor shape: {:?}", input_tensor.shape()); tracing::trace!("Input tensor shape: {:?}", input_tensor.shape());

1053
src/gui/app.rs Normal file

File diff suppressed because it is too large Load Diff

569
src/gui/bridge.rs Normal file
View File

@@ -0,0 +1,569 @@
use std::path::PathBuf;
use crate::errors;
use crate::facedet::{FaceDetectionConfig, FaceDetector, retinaface};
use crate::faceembed::facenet;
use crate::gui::app::{ComparisonResult, DetectionResult, ExecutorType};
use bounding_box::Aabb2;
use bounding_box::roi::MultiRoi as _;
use error_stack::ResultExt;
use fast_image_resize::ResizeOptions;
use ndarray::{Array1, Array2, Array3, Array4};
use ndarray_image::ImageToNdarray;
use ndarray_math::CosineSimilarity;
use ndarray_resize::NdFir;
const RETINAFACE_MODEL_MNN: &[u8] = include_bytes!("../../models/retinaface.mnn");
const FACENET_MODEL_MNN: &[u8] = include_bytes!("../../models/facenet.mnn");
const RETINAFACE_MODEL_ONNX: &[u8] = include_bytes!("../../models/retinaface.onnx");
const FACENET_MODEL_ONNX: &[u8] = include_bytes!("../../models/facenet.onnx");
pub struct FaceDetectionBridge;
impl FaceDetectionBridge {
pub async fn detect_faces(
image_path: PathBuf,
output_path: Option<PathBuf>,
threshold: f32,
nms_threshold: f32,
executor_type: ExecutorType,
) -> DetectionResult {
let start_time = std::time::Instant::now();
match Self::run_detection_internal(
image_path.clone(),
output_path,
threshold,
nms_threshold,
executor_type,
)
.await
{
Ok((faces_count, processed_image)) => {
let processing_time = start_time.elapsed().as_secs_f64();
DetectionResult::Success {
image_path,
faces_count,
processed_image,
processing_time,
}
}
Err(error) => DetectionResult::Error(error.to_string()),
}
}
pub async fn compare_faces(
image1_path: PathBuf,
image2_path: PathBuf,
threshold: f32,
nms_threshold: f32,
executor_type: ExecutorType,
) -> ComparisonResult {
let start_time = std::time::Instant::now();
match Self::run_comparison_internal(
image1_path,
image2_path,
threshold,
nms_threshold,
executor_type,
)
.await
{
Ok((
image1_faces,
image2_faces,
image1_face_rois,
image2_face_rois,
best_similarity,
)) => {
let processing_time = start_time.elapsed().as_secs_f64();
ComparisonResult::Success {
image1_faces,
image2_faces,
image1_face_rois,
image2_face_rois,
best_similarity,
processing_time,
}
}
Err(error) => ComparisonResult::Error(error.to_string()),
}
}
async fn run_detection_internal(
image_path: PathBuf,
output_path: Option<PathBuf>,
threshold: f32,
nms_threshold: f32,
executor_type: ExecutorType,
) -> Result<(usize, Option<Vec<u8>>), Box<dyn std::error::Error + Send + Sync>> {
// Load the image
let img = image::open(&image_path)?;
let img_rgb = img.to_rgb8();
// Convert to ndarray format
let image_array = img_rgb.as_ndarray()?;
// Create detection configuration
let config = FaceDetectionConfig::default()
.with_threshold(threshold)
.with_nms_threshold(nms_threshold)
.with_input_width(1024)
.with_input_height(1024);
// Create detector and detect faces
let faces = match executor_type {
ExecutorType::MnnCpu => {
let mut detector = retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
.with_forward_type(mnn::ForwardType::CPU)
.build()
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
detector
.detect_faces(image_array.view(), &config)
.map_err(|e| format!("Detection failed: {}", e))?
}
#[cfg(feature = "mnn-metal")]
ExecutorType::MnnMetal => {
let mut detector = retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
.with_forward_type(mnn::ForwardType::Metal)
.build()
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
detector
.detect_faces(image_array.view(), &config)
.map_err(|e| format!("Detection failed: {}", e))?
}
#[cfg(feature = "mnn-coreml")]
ExecutorType::MnnCoreML => {
let mut detector = retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
.with_forward_type(mnn::ForwardType::CoreML)
.build()
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
detector
.detect_faces(image_array.view(), &config)
.map_err(|e| format!("Detection failed: {}", e))?
}
ExecutorType::OnnxCpu => {
let mut detector = retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
.map_err(|e| format!("Failed to create ONNX detector: {}", e))?
.build()
.map_err(|e| format!("Failed to build ONNX detector: {}", e))?;
detector
.detect_faces(image_array.view(), &config)
.map_err(|e| format!("Detection failed: {}", e))?
}
#[cfg(feature = "ort-cuda")]
ExecutorType::OrtCuda => {
use crate::ort_ep::ExecutionProvider;
let ep = ExecutionProvider::CUDA;
let mut detector = retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
.map_err(|e| format!("Failed to create ONNX CUDA detector: {}", e))?
.with_execution_providers([ep])
.build()
.map_err(|e| format!("Failed to build ONNX CUDA detector: {}", e))?;
detector
.detect_faces(image_array.view(), &config)
.map_err(|e| format!("CUDA detection failed: {}", e))?
}
};
let faces_count = faces.bbox.len();
// Generate output image with bounding boxes if requested
let processed_image = if output_path.is_some() || true {
// Always generate for GUI display
let mut output_img = img.to_rgb8();
for bbox in &faces.bbox {
let min_point = bbox.min_vertex();
let size = bbox.size();
let rect = imageproc::rect::Rect::at(min_point.x as i32, min_point.y as i32)
.of_size(size.x as u32, size.y as u32);
imageproc::drawing::draw_hollow_rect_mut(
&mut output_img,
rect,
image::Rgb([255, 0, 0]),
);
}
// Convert to bytes for GUI display
let mut buffer = Vec::new();
let mut cursor = std::io::Cursor::new(&mut buffer);
image::DynamicImage::ImageRgb8(output_img.clone())
.write_to(&mut cursor, image::ImageFormat::Png)?;
// Save to file if output path is specified
if let Some(ref output_path) = output_path {
output_img.save(output_path)?;
}
Some(buffer)
} else {
None
};
Ok((faces_count, processed_image))
}
async fn run_comparison_internal(
image1_path: PathBuf,
image2_path: PathBuf,
threshold: f32,
nms_threshold: f32,
executor_type: ExecutorType,
) -> Result<
(usize, usize, Vec<Array3<u8>>, Vec<Array3<u8>>, f32),
Box<dyn std::error::Error + Send + Sync>,
> {
// Create detector and embedder, detect faces and generate embeddings
let (image1_faces, image2_faces, image1_rois, image2_rois, best_similarity) =
match executor_type {
ExecutorType::MnnCpu => {
let mut detector =
retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
.with_forward_type(mnn::ForwardType::CPU)
.build()
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
let mut embedder = facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
.map_err(|e| format!("Failed to create MNN embedder: {}", e))?
.with_forward_type(mnn::ForwardType::CPU)
.build()
.map_err(|e| format!("Failed to build MNN embedder: {}", e))?;
let img_1 = run_detection(
image1_path,
&mut detector,
&mut embedder,
threshold,
nms_threshold,
2,
)?;
let img_2 = run_detection(
image2_path,
&mut detector,
&mut embedder,
threshold,
nms_threshold,
2,
)?;
let image1_rois = img_1.rois;
let image2_rois = img_2.rois;
let image1_bbox_len = img_1.bbox.len();
let image2_bbox_len = img_2.bbox.len();
let best_similarity = compare_faces(&img_1.embeddings, &img_2.embeddings)?;
(
image1_bbox_len,
image2_bbox_len,
image1_rois,
image2_rois,
best_similarity,
)
}
#[cfg(feature = "mnn-metal")]
ExecutorType::MnnMetal => {
let mut detector =
retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
.with_forward_type(mnn::ForwardType::Metal)
.build()
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
let mut embedder = facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
.map_err(|e| format!("Failed to create MNN embedder: {}", e))?
.with_forward_type(mnn::ForwardType::Metal)
.build()
.map_err(|e| format!("Failed to build MNN embedder: {}", e))?;
let img_1 = run_detection(
image1_path,
&mut detector,
&mut embedder,
threshold,
nms_threshold,
2,
)?;
let img_2 = run_detection(
image2_path,
&mut detector,
&mut embedder,
threshold,
nms_threshold,
2,
)?;
let image1_rois = img_1.rois;
let image2_rois = img_2.rois;
let image1_bbox_len = img_1.bbox.len();
let image2_bbox_len = img_2.bbox.len();
let best_similarity = compare_faces(&img_1.embeddings, &img_2.embeddings)?;
(
image1_bbox_len,
image2_bbox_len,
image1_rois,
image2_rois,
best_similarity,
)
}
#[cfg(feature = "mnn-coreml")]
ExecutorType::MnnCoreML => {
let mut detector =
retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
.with_forward_type(mnn::ForwardType::CoreML)
.build()
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
let mut embedder = facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
.map_err(|e| format!("Failed to create MNN embedder: {}", e))?
.with_forward_type(mnn::ForwardType::CoreML)
.build()
.map_err(|e| format!("Failed to build MNN embedder: {}", e))?;
let img_1 = run_detection(
image1_path,
&mut detector,
&mut embedder,
threshold,
nms_threshold,
2,
)?;
let img_2 = run_detection(
image2_path,
&mut detector,
&mut embedder,
threshold,
nms_threshold,
2,
)?;
let image1_rois = img_1.rois;
let image2_rois = img_2.rois;
let image1_bbox_len = img_1.bbox.len();
let image2_bbox_len = img_2.bbox.len();
let best_similarity = compare_faces(&img_1.embeddings, &img_2.embeddings)?;
(
image1_bbox_len,
image2_bbox_len,
image1_rois,
image2_rois,
best_similarity,
)
}
ExecutorType::OnnxCpu => unimplemented!("ONNX face comparison not yet implemented"),
#[cfg(feature = "ort-cuda")]
ExecutorType::OrtCuda => {
use crate::ort_ep::ExecutionProvider;
let ep = ExecutionProvider::CUDA;
let mut detector =
retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
.map_err(|e| format!("Failed to create MNN detector: {}", e))?
.with_execution_providers([ep])
.build()
.map_err(|e| format!("Failed to build MNN detector: {}", e))?;
let mut embedder =
facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
.map_err(|e| format!("Failed to create MNN embedder: {}", e))?
.with_execution_providers([ep])
.build()
.map_err(|e| format!("Failed to build MNN embedder: {}", e))?;
let img_1 = run_detection(
image1_path,
&mut detector,
&mut embedder,
threshold,
nms_threshold,
2,
)?;
let img_2 = run_detection(
image2_path,
&mut detector,
&mut embedder,
threshold,
nms_threshold,
2,
)?;
let image1_rois = img_1.rois;
let image2_rois = img_2.rois;
let image1_bbox_len = img_1.bbox.len();
let image2_bbox_len = img_2.bbox.len();
let best_similarity = compare_faces(&img_1.embeddings, &img_2.embeddings)?;
(
image1_bbox_len,
image2_bbox_len,
image1_rois,
image2_rois,
best_similarity,
)
}
};
Ok((
image1_faces,
image2_faces,
image1_rois,
image2_rois,
best_similarity,
))
}
}
use crate::errors::Error;
pub fn compare_faces(
faces_1: &[Array1<f32>],
faces_2: &[Array1<f32>],
) -> Result<f32, error_stack::Report<crate::errors::Error>> {
use error_stack::Report;
if faces_1.is_empty() || faces_2.is_empty() {
Err(Report::new(crate::errors::Error))
.attach_printable("One or both images have no detected faces")?;
}
if faces_1.len() != faces_2.len() {
Err(Report::new(crate::errors::Error))
.attach_printable("Face count mismatch between images")?;
}
Ok(faces_1
.iter()
.zip(faces_2)
.flat_map(|(face_1, face_2)| face_1.cosine_similarity(face_2))
.inspect(|v| tracing::info!("Cosine similarity: {}", v))
.map(|v| ordered_float::OrderedFloat(v))
.max()
.map(|v| v.0)
.ok_or(Report::new(Error))?)
}
#[derive(Debug)]
pub struct DetectionOutput {
bbox: Vec<Aabb2<usize>>,
rois: Vec<ndarray::Array3<u8>>,
embeddings: Vec<Array1<f32>>,
}
fn run_detection<D, E>(
image: impl AsRef<std::path::Path>,
retinaface: &mut D,
facenet: &mut E,
threshold: f32,
nms_threshold: f32,
chunk_size: usize,
) -> crate::errors::Result<DetectionOutput>
where
D: crate::facedet::FaceDetector,
E: crate::faceembed::FaceEmbedder,
{
use errors::*;
// Initialize database if requested
let image = image.as_ref();
let image = image::open(image)
.change_context(Error)
.attach_printable(image.to_string_lossy().to_string())?;
let image = image.into_rgb8();
let mut array = image
.into_ndarray()
.change_context(errors::Error)
.attach_printable("Failed to convert image to ndarray")?;
let output = retinaface
.detect_faces(
array.view(),
&FaceDetectionConfig::default()
.with_threshold(threshold)
.with_nms_threshold(nms_threshold),
)
.change_context(errors::Error)
.attach_printable("Failed to detect faces")?;
let bboxes = output
.bbox
.iter()
.inspect(|bbox| tracing::info!("Raw bbox: {:?}", bbox))
.map(|bbox| bbox.as_::<f32>().scale_uniform(1.30).as_::<usize>())
.inspect(|bbox| tracing::info!("Padded bbox: {:?}", bbox))
.collect_vec();
for bbox in &bboxes {
tracing::info!("Detected face: {:?}", bbox);
use bounding_box::draw::*;
array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1);
}
use itertools::Itertools;
let face_rois = array
.view()
.multi_roi(&bboxes)
.change_context(Error)?
.into_iter()
.map(|roi| {
roi.as_standard_layout()
.fast_resize(224, 224, &ResizeOptions::default())
.change_context(Error)
})
.collect::<Result<Vec<_>>>()?;
let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
let embeddings: Vec<Array1<f32>> = face_roi_views
.chunks(chunk_size)
.map(|chunk| {
tracing::info!("Processing chunk of size: {}", chunk.len());
let og_size = chunk.len();
if chunk.len() < chunk_size {
tracing::warn!("Chunk size is less than 8, padding with zeros");
let zeros = Array3::zeros((224, 224, 3));
let chunk: Vec<_> = chunk
.iter()
.map(|arr| arr.reborrow())
.chain(core::iter::repeat(zeros.view()))
.take(chunk_size)
.collect();
let face_rois: Array4<u8> = ndarray::stack(ndarray::Axis(0), chunk.as_slice())
.change_context(errors::Error)
.attach_printable("Failed to stack rois together")?;
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
Ok((output, og_size))
} else {
let face_rois: Array4<u8> = ndarray::stack(ndarray::Axis(0), chunk)
.change_context(errors::Error)
.attach_printable("Failed to stack rois together")?;
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
Ok((output, og_size))
}
})
.collect::<Result<Vec<(Array2<f32>, usize)>>>()?
.into_iter()
.map(|(chunk, size): (Array2<f32>, usize)| {
use itertools::Itertools;
chunk
.rows()
.into_iter()
.take(size)
.map(|row| row.to_owned())
.collect_vec()
.into_iter()
})
.flatten()
.collect::<Vec<Array1<f32>>>();
Ok(DetectionOutput {
bbox: bboxes,
rois: face_rois,
embeddings,
})
}

5
src/gui/mod.rs Normal file
View File

@@ -0,0 +1,5 @@
pub mod app;
pub mod bridge;
pub use app::{FaceDetectorApp, Message, run};
pub use bridge::FaceDetectionBridge;

View File

@@ -1,5 +0,0 @@
// pub struct Image {
// pub width: u32,
// pub height: u32,
// pub data: Vec<u8>,
// }

View File

@@ -1,7 +1,7 @@
pub mod database;
pub mod errors; pub mod errors;
pub mod facedet; pub mod facedet;
pub mod faceembed; pub mod faceembed;
pub mod image; pub mod gui;
pub mod ort_ep; pub mod ort_ep;
pub use errors::*;
use errors::*;

View File

@@ -1,175 +0,0 @@
mod cli;
mod errors;
use bounding_box::roi::MultiRoi;
use detector::{facedet, facedet::FaceDetectionConfig, faceembed};
use errors::*;
use fast_image_resize::ResizeOptions;
use ndarray::*;
use ndarray_image::*;
use ndarray_resize::NdFir;
const RETINAFACE_MODEL_MNN: &[u8] = include_bytes!("../models/retinaface.mnn");
const FACENET_MODEL_MNN: &[u8] = include_bytes!("../models/facenet.mnn");
const RETINAFACE_MODEL_ONNX: &[u8] = include_bytes!("../models/retinaface.onnx");
const FACENET_MODEL_ONNX: &[u8] = include_bytes!("../models/facenet.onnx");
pub fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_env_filter("error")
.with_thread_ids(true)
.with_thread_names(true)
.with_target(false)
.init();
let args = <cli::Cli as clap::Parser>::parse();
match args.cmd {
cli::SubCommand::Detect(detect) => {
// Choose backend based on executor type (defaulting to MNN for backward compatibility)
let executor = detect
.mnn_forward_type
.map(|f| cli::Executor::Mnn(f))
.or_else(|| {
if detect.ort_execution_provider.is_empty() {
None
} else {
Some(cli::Executor::Ort(detect.ort_execution_provider.clone()))
}
})
.unwrap_or(cli::Executor::Mnn(mnn::ForwardType::CPU));
match executor {
cli::Executor::Mnn(forward) => {
let retinaface =
facedet::retinaface::mnn::FaceDetection::builder(RETINAFACE_MODEL_MNN)
.change_context(Error)?
.with_forward_type(forward)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face detection model")?;
let facenet =
faceembed::facenet::mnn::EmbeddingGenerator::builder(FACENET_MODEL_MNN)
.change_context(Error)?
.with_forward_type(forward)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face embedding model")?;
run_detection(detect, retinaface, facenet)?;
}
cli::Executor::Ort(ep) => {
let retinaface =
facedet::retinaface::ort::FaceDetection::builder(RETINAFACE_MODEL_ONNX)
.change_context(Error)?
.with_execution_providers(&ep)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face detection model")?;
let facenet =
faceembed::facenet::ort::EmbeddingGenerator::builder(FACENET_MODEL_ONNX)
.change_context(Error)?
.with_execution_providers(ep)
.build()
.change_context(errors::Error)
.attach_printable("Failed to create face embedding model")?;
run_detection(detect, retinaface, facenet)?;
}
}
}
cli::SubCommand::List(list) => {
println!("List: {:?}", list);
}
cli::SubCommand::Completions { shell } => {
cli::Cli::completions(shell);
}
}
Ok(())
}
fn run_detection<D, E>(detect: cli::Detect, mut retinaface: D, mut facenet: E) -> Result<()>
where
D: facedet::FaceDetector,
E: faceembed::FaceEmbedder,
{
let image = image::open(&detect.image)
.change_context(Error)
.attach_printable(detect.image.to_string_lossy().to_string())?;
let image = image.into_rgb8();
let mut array = image
.into_ndarray()
.change_context(errors::Error)
.attach_printable("Failed to convert image to ndarray")?;
let output = retinaface
.detect_faces(
array.view(),
FaceDetectionConfig::default()
.with_threshold(detect.threshold)
.with_nms_threshold(detect.nms_threshold),
)
.change_context(errors::Error)
.attach_printable("Failed to detect faces")?;
for bbox in &output.bbox {
tracing::info!("Detected face: {:?}", bbox);
use bounding_box::draw::*;
array.draw(bbox, color::palette::css::GREEN_YELLOW.to_rgba8(), 1);
}
let face_rois = array
.view()
.multi_roi(&output.bbox)
.change_context(Error)?
.into_iter()
// .inspect(|f| {
// tracing::info!("Face ROI shape before resize: {:?}", f.dim());
// })
.map(|roi| {
roi.as_standard_layout()
.fast_resize(160, 160, &ResizeOptions::default())
.change_context(Error)
})
// .inspect(|f| {
// f.as_ref().inspect(|f| {
// tracing::info!("Face ROI shape after resize: {:?}", f.dim());
// });
// })
.collect::<Result<Vec<_>>>()?;
let face_roi_views = face_rois.iter().map(|roi| roi.view()).collect::<Vec<_>>();
let chunk_size = detect.batch_size;
let embeddings = face_roi_views
.chunks(chunk_size)
.map(|chunk| {
tracing::info!("Processing chunk of size: {}", chunk.len());
if chunk.len() < chunk_size {
tracing::warn!("Chunk size is less than 8, padding with zeros");
let zeros = Array3::zeros((160, 160, 3));
let zero_array = core::iter::repeat(zeros.view())
.take(chunk_size)
.collect::<Vec<_>>();
let face_rois: Array4<u8> = ndarray::stack(Axis(0), zero_array.as_slice())
.change_context(errors::Error)
.attach_printable("Failed to stack rois together")?;
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
Ok(output)
} else {
let face_rois: Array4<u8> = ndarray::stack(Axis(0), chunk)
.change_context(errors::Error)
.attach_printable("Failed to stack rois together")?;
let output = facenet.run_models(face_rois.view()).change_context(Error)?;
Ok(output)
}
})
.collect::<Result<Vec<Array2<f32>>>>()?;
let v = array.view();
if let Some(output) = detect.output {
let image: image::RgbImage = v
.to_image()
.change_context(errors::Error)
.attach_printable("Failed to convert ndarray to image")?;
image
.save(output)
.change_context(errors::Error)
.attach_printable("Failed to save output image")?;
}
Ok(())
}

View File

@@ -13,7 +13,7 @@ use ort::execution_providers::TensorRTExecutionProvider;
use ort::execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch}; use ort::execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch};
/// Supported execution providers for ONNX Runtime /// Supported execution providers for ONNX Runtime
#[derive(Debug, Clone)] #[derive(Debug, Copy, Clone)]
pub enum ExecutionProvider { pub enum ExecutionProvider {
/// CPU execution provider (always available) /// CPU execution provider (always available)
CPU, CPU,