feat(ndarray-safetensors): add tensor_by_index method for SafeArraysView
This commit is contained in:
@@ -114,6 +114,22 @@ impl<'a> SafeArraysView<'a> {
|
|||||||
.map(|array_view| array_view.into_dimensionality::<Dim>())??)
|
.map(|array_view| array_view.into_dimensionality::<Dim>())??)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn tensor_by_index<T: STDtype, Dim: ndarray::Dimension>(
|
||||||
|
&self,
|
||||||
|
index: usize,
|
||||||
|
) -> Result<ndarray::ArrayView<'a, T, Dim>> {
|
||||||
|
self.tensors
|
||||||
|
.iter()
|
||||||
|
.nth(index)
|
||||||
|
.ok_or(SafeTensorError::TensorNotFound(format!(
|
||||||
|
"Index {} out of bounds",
|
||||||
|
index
|
||||||
|
)))
|
||||||
|
.map(|(_, tensor)| tensor_view_to_array_view(tensor))?
|
||||||
|
.map(|array_view| array_view.into_dimensionality::<Dim>())?
|
||||||
|
.map_err(SafeTensorError::NdarrayShapeError)
|
||||||
|
}
|
||||||
|
|
||||||
/// Get an iterator over tensor names
|
/// Get an iterator over tensor names
|
||||||
pub fn names(&self) -> std::vec::IntoIter<&str> {
|
pub fn names(&self) -> std::vec::IntoIter<&str> {
|
||||||
self.tensors.names().into_iter()
|
self.tensors.names().into_iter()
|
||||||
|
|||||||
@@ -551,10 +551,10 @@ fn add_sqlite_cosine_similarity(db: &Connection) -> Result<()> {
|
|||||||
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||||
|
|
||||||
let array_view_1 = array_1_st
|
let array_view_1 = array_1_st
|
||||||
.tensor::<f32, ndarray::Ix1>("embedding")
|
.tensor_by_index::<f32, ndarray::Ix1>(0)
|
||||||
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||||
let array_view_2 = array_2_st
|
let array_view_2 = array_2_st
|
||||||
.tensor::<f32, ndarray::Ix1>("embedding")
|
.tensor_by_index::<f32, ndarray::Ix1>(0)
|
||||||
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
|
||||||
|
|
||||||
let similarity = array_view_1
|
let similarity = array_view_1
|
||||||
|
|||||||
@@ -130,7 +130,7 @@ where
|
|||||||
.attach_printable("Failed to detect faces")?;
|
.attach_printable("Failed to detect faces")?;
|
||||||
|
|
||||||
// Store image and face detections in database if requested
|
// Store image and face detections in database if requested
|
||||||
let (image_id, face_ids) = if let Some(ref database) = db {
|
let (_image_id, face_ids) = if let Some(ref database) = db {
|
||||||
let image_path = detect.image.to_string_lossy();
|
let image_path = detect.image.to_string_lossy();
|
||||||
let img_id = database
|
let img_id = database
|
||||||
.store_image(&image_path, image_width, image_height)
|
.store_image(&image_path, image_width, image_height)
|
||||||
|
|||||||
Reference in New Issue
Block a user