feat(ndarray-safetensors): add tensor_by_index method for SafeArraysView
Some checks failed
build / checks-matrix (push) Successful in 19m24s
build / codecov (push) Failing after 19m27s
docs / docs (push) Failing after 28m51s
build / checks-build (push) Has been cancelled

This commit is contained in:
uttarayan21
2025-08-20 16:05:18 +05:30
parent 97f64e7e10
commit f8122892e0
3 changed files with 19 additions and 3 deletions

View File

@@ -114,6 +114,22 @@ impl<'a> SafeArraysView<'a> {
.map(|array_view| array_view.into_dimensionality::<Dim>())??)
}
pub fn tensor_by_index<T: STDtype, Dim: ndarray::Dimension>(
&self,
index: usize,
) -> Result<ndarray::ArrayView<'a, T, Dim>> {
self.tensors
.iter()
.nth(index)
.ok_or(SafeTensorError::TensorNotFound(format!(
"Index {} out of bounds",
index
)))
.map(|(_, tensor)| tensor_view_to_array_view(tensor))?
.map(|array_view| array_view.into_dimensionality::<Dim>())?
.map_err(SafeTensorError::NdarrayShapeError)
}
/// Get an iterator over tensor names
pub fn names(&self) -> std::vec::IntoIter<&str> {
self.tensors.names().into_iter()