feat(ndarray-safetensors): add tensor_by_index method for SafeArraysView
This commit is contained in:
@@ -114,6 +114,22 @@ impl<'a> SafeArraysView<'a> {
|
||||
.map(|array_view| array_view.into_dimensionality::<Dim>())??)
|
||||
}
|
||||
|
||||
pub fn tensor_by_index<T: STDtype, Dim: ndarray::Dimension>(
|
||||
&self,
|
||||
index: usize,
|
||||
) -> Result<ndarray::ArrayView<'a, T, Dim>> {
|
||||
self.tensors
|
||||
.iter()
|
||||
.nth(index)
|
||||
.ok_or(SafeTensorError::TensorNotFound(format!(
|
||||
"Index {} out of bounds",
|
||||
index
|
||||
)))
|
||||
.map(|(_, tensor)| tensor_view_to_array_view(tensor))?
|
||||
.map(|array_view| array_view.into_dimensionality::<Dim>())?
|
||||
.map_err(SafeTensorError::NdarrayShapeError)
|
||||
}
|
||||
|
||||
/// Get an iterator over tensor names
|
||||
pub fn names(&self) -> std::vec::IntoIter<&str> {
|
||||
self.tensors.names().into_iter()
|
||||
|
||||
Reference in New Issue
Block a user