diff --git a/ndarray-safetensors/src/lib.rs b/ndarray-safetensors/src/lib.rs index 5a17d64..ab7e96d 100644 --- a/ndarray-safetensors/src/lib.rs +++ b/ndarray-safetensors/src/lib.rs @@ -114,6 +114,22 @@ impl<'a> SafeArraysView<'a> { .map(|array_view| array_view.into_dimensionality::())??) } + pub fn tensor_by_index( + &self, + index: usize, + ) -> Result> { + self.tensors + .iter() + .nth(index) + .ok_or(SafeTensorError::TensorNotFound(format!( + "Index {} out of bounds", + index + ))) + .map(|(_, tensor)| tensor_view_to_array_view(tensor))? + .map(|array_view| array_view.into_dimensionality::())? + .map_err(SafeTensorError::NdarrayShapeError) + } + /// Get an iterator over tensor names pub fn names(&self) -> std::vec::IntoIter<&str> { self.tensors.names().into_iter() diff --git a/src/database.rs b/src/database.rs index 6338cce..35211e7 100644 --- a/src/database.rs +++ b/src/database.rs @@ -551,10 +551,10 @@ fn add_sqlite_cosine_similarity(db: &Connection) -> Result<()> { .map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?; let array_view_1 = array_1_st - .tensor::("embedding") + .tensor_by_index::(0) .map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?; let array_view_2 = array_2_st - .tensor::("embedding") + .tensor_by_index::(0) .map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?; let similarity = array_view_1 diff --git a/src/main.rs b/src/main.rs index c661f14..387115c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -130,7 +130,7 @@ where .attach_printable("Failed to detect faces")?; // Store image and face detections in database if requested - let (image_id, face_ids) = if let Some(ref database) = db { + let (_image_id, face_ids) = if let Some(ref database) = db { let image_path = detect.image.to_string_lossy(); let img_id = database .store_image(&image_path, image_width, image_height)