Coverage for src/hdmf/common/sparse.py: 100%
42 statements
« prev ^ index » next coverage.py v7.2.5, created at 2023-07-25 05:02 +0000
1import scipy.sparse as sps
2from . import register_class
3from ..container import Container
4from ..utils import docval, popargs, to_uint_array, get_data_shape, AllowPositional
@register_class('CSRMatrix')
class CSRMatrix(Container):
    """Container holding a sparse matrix in compressed sparse row (CSR) form.

    The matrix is stored internally as a ``scipy.sparse.csr_matrix``. Attribute
    lookups that the container itself does not resolve are forwarded to that
    matrix, which is how ``data``, ``indices``, ``indptr`` and ``shape`` are
    exposed on instances.
    """

    @docval({'name': 'data', 'type': (sps.csr_matrix, 'array_data'),
             'doc': 'the data to use for this CSRMatrix or CSR data array. '
                    'If passing CSR data array, *indices*, *indptr*, and *shape* must also be provided'},
            {'name': 'indices', 'type': 'array_data', 'doc': 'CSR index array', 'default': None},
            {'name': 'indptr', 'type': 'array_data', 'doc': 'CSR index pointer array', 'default': None},
            {'name': 'shape', 'type': 'array_data', 'doc': 'the shape of the matrix', 'default': None},
            {'name': 'name', 'type': str, 'doc': 'the name to use for this when storing', 'default': 'csr_matrix'},
            allow_positional=AllowPositional.WARNING)
    def __init__(self, **kwargs):
        """Build the CSR matrix from a csr_matrix, a 2D array, or 1D CSR component arrays."""
        data, indices, indptr, shape = popargs('data', 'indices', 'indptr', 'shape', kwargs)
        super().__init__(**kwargs)
        if not isinstance(data, sps.csr_matrix):
            temp_shape = get_data_shape(data)
            temp_ndim = len(temp_shape)
            if temp_ndim == 2:
                # A dense 2D array is converted directly; indices/indptr/shape are not consulted.
                data = sps.csr_matrix(data)
            elif temp_ndim == 1:
                # A 1D array is treated as the CSR "data" component, so the other
                # three components are all required to assemble the matrix.
                if any(_ is None for _ in (indptr, indices, shape)):
                    raise ValueError("Must specify 'indptr', 'indices', and 'shape' arguments when passing data array.")
                indptr = self.__check_arr(indptr, 'indptr')
                indices = self.__check_arr(indices, 'indices')
                shape = self.__check_arr(shape, 'shape')
                if len(shape) != 2:
                    raise ValueError("'shape' argument must specify two and only two dimensions.")
                data = sps.csr_matrix((data, indices, indptr), shape=shape)
            else:
                raise ValueError("'data' argument cannot be ndarray of dimensionality > 2.")
        self.__data = data

    @staticmethod
    def __check_arr(ar, arg):
        """Convert *ar* to a 1D unsigned integer array, naming *arg* in any error.

        Raises ValueError if the conversion fails or the result is not 1D.
        """
        try:
            ar = to_uint_array(ar)
        except ValueError as ve:
            raise ValueError("Cannot convert '%s' to an array of unsigned integers." % arg) from ve
        if ar.ndim != 1:
            raise ValueError("'%s' must be a 1D array of unsigned integers." % arg)
        return ar

    def __getattr__(self, val):
        """Forward unresolved attribute lookups to the wrapped csr_matrix.

        ``indices``, ``indptr`` and ``shape`` are returned as unsigned integer
        arrays, because the csr_matrix may hold them as signed int arrays.
        """
        # __getattr__ only runs when normal lookup fails. If the private matrix
        # attribute itself is missing (e.g. an instance created via __new__ without
        # __init__, as copy/pickle do before probing __setstate__), reading
        # self.__data below would re-enter __getattr__ forever. Fail cleanly instead.
        if val == '_CSRMatrix__data':
            raise AttributeError("'%s' object has no attribute '%s'" % (type(self).__name__, val))
        # NOTE: this provides access to self.data, self.indices, self.indptr, self.shape
        attr = getattr(self.__data, val)
        if val in ('indices', 'indptr', 'shape'):  # needed because sps.csr_matrix may contain int arrays for these
            attr = to_uint_array(attr)
        return attr

    def to_spmat(self):
        """Return the underlying ``scipy.sparse.csr_matrix`` (not a copy)."""
        return self.__data