feat: Added stuff
Some checks failed
build / checks-matrix (push) Successful in 19m21s
build / codecov (push) Failing after 19m19s
docs / docs (push) Failing after 28m51s
build / checks-build (push) Has been cancelled

This commit is contained in:
uttarayan21
2025-08-19 20:52:29 +05:30
parent 33798467ba
commit d8bf68dfed
5 changed files with 15 additions and 67 deletions

6
Cargo.lock generated
View File

@@ -1299,6 +1299,7 @@ dependencies = [
[[package]] [[package]]
name = "mnn" name = "mnn"
version = "0.2.0" version = "0.2.0"
source = "git+https://github.com/uttarayan21/mnn-rs?branch=restructure-tensor-type#4128b5b40e03c8744fc0e68f6684ef8a2dd971e5"
dependencies = [ dependencies = [
"dunce", "dunce",
"error-stack", "error-stack",
@@ -1312,7 +1313,7 @@ dependencies = [
[[package]] [[package]]
name = "mnn-bridge" name = "mnn-bridge"
version = "0.1.0" version = "0.1.0"
source = "git+https://github.com/uttarayan21/mnn-rs?branch=restructure-tensor-type#456c53307ff551d8cb8e4e380c7febf7c16ba0ab" source = "git+https://github.com/uttarayan21/mnn-rs?branch=restructure-tensor-type#4128b5b40e03c8744fc0e68f6684ef8a2dd971e5"
dependencies = [ dependencies = [
"error-stack", "error-stack",
"mnn", "mnn",
@@ -1322,7 +1323,7 @@ dependencies = [
[[package]] [[package]]
name = "mnn-sync" name = "mnn-sync"
version = "0.1.0" version = "0.1.0"
source = "git+https://github.com/uttarayan21/mnn-rs?branch=restructure-tensor-type#456c53307ff551d8cb8e4e380c7febf7c16ba0ab" source = "git+https://github.com/uttarayan21/mnn-rs?branch=restructure-tensor-type#4128b5b40e03c8744fc0e68f6684ef8a2dd971e5"
dependencies = [ dependencies = [
"error-stack", "error-stack",
"flume", "flume",
@@ -1334,6 +1335,7 @@ dependencies = [
[[package]] [[package]]
name = "mnn-sys" name = "mnn-sys"
version = "0.1.0" version = "0.1.0"
source = "git+https://github.com/uttarayan21/mnn-rs?branch=restructure-tensor-type#4128b5b40e03c8744fc0e68f6684ef8a2dd971e5"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"bindgen", "bindgen",

View File

@@ -5,15 +5,12 @@ members = ["ndarray-image", "ndarray-resize", ".", "bounding-box"]
version = "0.1.0" version = "0.1.0"
edition = "2024" edition = "2024"
[patch."https://github.com/uttarayan21/mnn-rs"]
mnn = { path = "/Users/fs0c131y/Projects/aftershoot/mnn-rs" }
[workspace.dependencies] [workspace.dependencies]
ndarray-image = { path = "ndarray-image" } ndarray-image = { path = "ndarray-image" }
ndarray-resize = { path = "ndarray-resize" } ndarray-resize = { path = "ndarray-resize" }
mnn = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.2.0", features = [ mnn = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.2.0", features = [
"metal", # "metal",
"coreml", # "coreml",
"tracing", "tracing",
], branch = "restructure-tensor-type" } ], branch = "restructure-tensor-type" }
mnn-bridge = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0", features = [ mnn-bridge = { git = "https://github.com/uttarayan21/mnn-rs", version = "0.1.0", features = [
@@ -53,7 +50,11 @@ bounding-box = { version = "0.1.0", path = "bounding-box" }
color = "0.3.1" color = "0.3.1"
itertools = "0.14.0" itertools = "0.14.0"
ordered-float = "5.0.0" ordered-float = "5.0.0"
ort = { version = "2.0.0-rc.10", default-features = false, features = [ "std", "tracing", "ndarray"]} ort = { version = "2.0.0-rc.10", default-features = false, features = [
"std",
"tracing",
"ndarray",
] }
[profile.release] [profile.release]
debug = true debug = true
@@ -65,5 +66,7 @@ ort-tensorrt = ["ort/tensorrt"]
ort-tvm = ["ort/tvm"] ort-tvm = ["ort/tvm"]
ort-openvino = ["ort/openvino"] ort-openvino = ["ort/openvino"]
ort-directml = ["ort/directml"] ort-directml = ["ort/directml"]
mnn-metal = ["mnn/metal"]
mnn-coreml = ["mnn/coreml"]
default = ["ort-coreml"] default = []

View File

@@ -52,7 +52,7 @@
mnn = mnn-overlay.packages.${system}.mnn.override { mnn = mnn-overlay.packages.${system}.mnn.override {
src = mnn-src; src = mnn-src;
buildConverter = true; buildConverter = true;
enableMetal = true; enableMetal = pkgs.stdenv.isDarwin;
enableOpencl = true; enableOpencl = true;
}; };
}) })

View File

@@ -61,35 +61,6 @@ impl FaceDetection {
) -> std::result::Result<FaceDetectionBuilder, error_stack::Report<crate::errors::Error>> { ) -> std::result::Result<FaceDetectionBuilder, error_stack::Report<crate::errors::Error>> {
FaceDetectionBuilder::new(model) FaceDetectionBuilder::new(model)
} }
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
let model = std::fs::read(path)
.change_context(Error)
.attach_printable("Failed to read model file")?;
Self::new_from_bytes(&model)
}
pub fn new_from_bytes(model: &[u8]) -> Result<Self> {
tracing::info!("Loading face detection model from bytes");
let mut model = mnn::Interpreter::from_bytes(model)
.map_err(|e| e.into_inner())
.change_context(Error)
.attach_printable("Failed to load model from bytes")?;
model.set_session_mode(mnn::SessionMode::Release);
model
.set_cache_file("retinaface.cache", 128)
.change_context(Error)
.attach_printable("Failed to set cache file")?;
let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High);
let sc = mnn::ScheduleConfig::new()
.with_type(mnn::ForwardType::Metal)
.with_backend_config(bc);
tracing::info!("Creating session handle for face detection model");
let handle = mnn_sync::SessionHandle::new(model, sc)
.change_context(Error)
.attach_printable("Failed to create session handle")?;
Ok(FaceDetection { handle })
}
} }
impl FaceDetector for FaceDetection { impl FaceDetector for FaceDetection {

View File

@@ -56,12 +56,6 @@ impl EmbeddingGeneratorBuilder {
impl EmbeddingGenerator { impl EmbeddingGenerator {
const INPUT_NAME: &'static str = "serving_default_input_6:0"; const INPUT_NAME: &'static str = "serving_default_input_6:0";
const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0"; const OUTPUT_NAME: &'static str = "StatefulPartitionedCall:0";
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
let model = std::fs::read(path)
.change_context(Error)
.attach_printable("Failed to read model file")?;
Self::new_from_bytes(&model)
}
pub fn builder<T: AsRef<[u8]>>( pub fn builder<T: AsRef<[u8]>>(
model: T, model: T,
@@ -69,28 +63,6 @@ impl EmbeddingGenerator {
EmbeddingGeneratorBuilder::new(model) EmbeddingGeneratorBuilder::new(model)
} }
pub fn new_from_bytes(model: &[u8]) -> Result<Self> {
tracing::info!("Loading face embedding model from bytes");
let mut model = mnn::Interpreter::from_bytes(model)
.map_err(|e| e.into_inner())
.change_context(Error)
.attach_printable("Failed to load model from bytes")?;
model.set_session_mode(mnn::SessionMode::Release);
model
.set_cache_file("facenet.cache", 128)
.change_context(Error)
.attach_printable("Failed to set cache file")?;
let bc = mnn::BackendConfig::default().with_memory_mode(mnn::MemoryMode::High);
let sc = mnn::ScheduleConfig::new()
.with_type(mnn::ForwardType::Metal)
.with_backend_config(bc);
tracing::info!("Creating session handle for face embedding model");
let handle = mnn_sync::SessionHandle::new(model, sc)
.change_context(Error)
.attach_printable("Failed to create session handle")?;
Ok(Self { handle })
}
pub fn run_models(&self, face: ArrayView4<u8>) -> Result<Array2<f32>> { pub fn run_models(&self, face: ArrayView4<u8>) -> Result<Array2<f32>> {
let tensor = face let tensor = face
// .permuted_axes((0, 3, 1, 2)) // .permuted_axes((0, 3, 1, 2))