Coverage for src/hdmf/common/sparse.py: 100%

42 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-08-18 20:49 +0000

1import scipy.sparse as sps 

2from . import register_class 

3from ..container import Container 

4from ..utils import docval, popargs, to_uint_array, get_data_shape, AllowPositional 

5 

6 

@register_class('CSRMatrix')
class CSRMatrix(Container):
    """Container holding a sparse matrix in Compressed Sparse Row (CSR) format.

    The matrix is stored internally as a ``scipy.sparse.csr_matrix``. Attributes not
    found on this container (e.g. ``data``, ``indices``, ``indptr``, ``shape``) are
    looked up on the wrapped matrix; ``indices``, ``indptr`` and ``shape`` are
    converted with ``to_uint_array`` before being returned.
    """

    @docval({'name': 'data', 'type': (sps.csr_matrix, 'array_data'),
             'doc': 'the data to use for this CSRMatrix or CSR data array. '
                    'If passing CSR data array, *indices*, *indptr*, and *shape* must also be provided'},
            {'name': 'indices', 'type': 'array_data', 'doc': 'CSR index array', 'default': None},
            {'name': 'indptr', 'type': 'array_data', 'doc': 'CSR index pointer array', 'default': None},
            {'name': 'shape', 'type': 'array_data', 'doc': 'the shape of the matrix', 'default': None},
            {'name': 'name', 'type': str, 'doc': 'the name to use for this when storing', 'default': 'csr_matrix'},
            allow_positional=AllowPositional.WARNING)
    def __init__(self, **kwargs):
        # Accepted inputs:
        #   * a scipy.sparse.csr_matrix            -> stored as-is
        #   * a 2D array-like (dense matrix)       -> converted with sps.csr_matrix(data)
        #   * a 1D array of CSR values             -> combined with indices/indptr/shape
        # Any other dimensionality raises ValueError.
        data, indices, indptr, shape = popargs('data', 'indices', 'indptr', 'shape', kwargs)
        super().__init__(**kwargs)
        if not isinstance(data, sps.csr_matrix):
            temp_shape = get_data_shape(data)
            temp_ndim = len(temp_shape)
            if temp_ndim == 2:
                data = sps.csr_matrix(data)
            elif temp_ndim == 1:
                # A flat data array is only meaningful together with the CSR index arrays.
                if any(_ is None for _ in (indptr, indices, shape)):
                    raise ValueError("Must specify 'indptr', 'indices', and 'shape' arguments when passing data array.")
                indptr = self.__check_arr(indptr, 'indptr')
                indices = self.__check_arr(indices, 'indices')
                shape = self.__check_arr(shape, 'shape')
                if len(shape) != 2:
                    raise ValueError("'shape' argument must specify two and only two dimensions.")
                data = sps.csr_matrix((data, indices, indptr), shape=shape)
            else:
                raise ValueError("'data' argument cannot be ndarray of dimensionality > 2.")
        self.__data = data

    @staticmethod
    def __check_arr(ar, arg):
        """Convert *ar* to a 1D unsigned-integer array.

        :param ar: array-like to convert
        :param arg: argument name, used only in error messages
        :returns: the converted array (as returned by ``to_uint_array``)
        :raises ValueError: if conversion fails or the result is not 1D
        """
        try:
            ar = to_uint_array(ar)
        except ValueError as ve:
            raise ValueError("Cannot convert '%s' to an array of unsigned integers." % arg) from ve
        if ar.ndim != 1:
            raise ValueError("'%s' must be a 1D array of unsigned integers." % arg)
        return ar

    def __getattr__(self, val):
        """Delegate unknown attribute lookups to the wrapped ``csr_matrix``."""
        # NOTE: this provides access to self.data, self.indices, self.indptr, self.shape
        # __getattr__ only runs when normal lookup fails. If the wrapped matrix has not been
        # assigned yet (e.g. a lookup during __init__ before ``self.__data = data``, or an
        # instance created without running __init__), evaluating ``self.__data`` below would
        # call __getattr__ again for the same name and recurse until RecursionError.
        # Raise the AttributeError that attribute-lookup protocols (hasattr, getattr with a
        # default, copy/pickle) expect instead.
        if val == '_CSRMatrix__data':
            raise AttributeError("%r object has no attribute %r" % (type(self).__name__, val))
        attr = getattr(self.__data, val)
        if val in ('indices', 'indptr', 'shape'):  # needed because sps.csr_matrix may contain int arrays for these
            attr = to_uint_array(attr)
        return attr

    def to_spmat(self):
        """Return the wrapped ``scipy.sparse.csr_matrix`` (the stored object itself, not a copy)."""
        return self.__data