Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import functools 

2import operator 

3import sys 

4import warnings 

5import numbers 

6from collections import namedtuple 

7from multiprocessing import Pool 

8import inspect 

9 

10import numpy as np 

11 

12try: 

13 from numpy.random import Generator as Generator 

14except ImportError: 

15 class Generator(): # type: ignore[no-redef] 

16 pass 

17 

18 

19def _valarray(shape, value=np.nan, typecode=None): 

20 """Return an array of all values. 

21 """ 

22 

23 out = np.ones(shape, dtype=bool) * value 

24 if typecode is not None: 

25 out = out.astype(typecode) 

26 if not isinstance(out, np.ndarray): 

27 out = np.asarray(out) 

28 return out 

29 

30 

31def _lazywhere(cond, arrays, f, fillvalue=None, f2=None): 

32 """ 

33 np.where(cond, x, fillvalue) always evaluates x even where cond is False. 

34 This one only evaluates f(arr1[cond], arr2[cond], ...). 

35 For example, 

36 >>> a, b = np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]) 

37 >>> def f(a, b): 

38 return a*b 

39 >>> _lazywhere(a > 2, (a, b), f, np.nan) 

40 array([ nan, nan, 21., 32.]) 

41 

42 Notice, it assumes that all `arrays` are of the same shape, or can be 

43 broadcasted together. 

44 

45 """ 

46 if fillvalue is None: 

47 if f2 is None: 

48 raise ValueError("One of (fillvalue, f2) must be given.") 

49 else: 

50 fillvalue = np.nan 

51 else: 

52 if f2 is not None: 

53 raise ValueError("Only one of (fillvalue, f2) can be given.") 

54 

55 arrays = np.broadcast_arrays(*arrays) 

56 temp = tuple(np.extract(cond, arr) for arr in arrays) 

57 tcode = np.mintypecode([a.dtype.char for a in arrays]) 

58 out = _valarray(np.shape(arrays[0]), value=fillvalue, typecode=tcode) 

59 np.place(out, cond, f(*temp)) 

60 if f2 is not None: 

61 temp = tuple(np.extract(~cond, arr) for arr in arrays) 

62 np.place(out, ~cond, f2(*temp)) 

63 

64 return out 

65 

66 

67def _lazyselect(condlist, choicelist, arrays, default=0): 

68 """ 

69 Mimic `np.select(condlist, choicelist)`. 

70 

71 Notice, it assumes that all `arrays` are of the same shape or can be 

72 broadcasted together. 

73 

74 All functions in `choicelist` must accept array arguments in the order 

75 given in `arrays` and must return an array of the same shape as broadcasted 

76 `arrays`. 

77 

78 Examples 

79 -------- 

80 >>> x = np.arange(6) 

81 >>> np.select([x <3, x > 3], [x**2, x**3], default=0) 

82 array([ 0, 1, 4, 0, 64, 125]) 

83 

84 >>> _lazyselect([x < 3, x > 3], [lambda x: x**2, lambda x: x**3], (x,)) 

85 array([ 0., 1., 4., 0., 64., 125.]) 

86 

87 >>> a = -np.ones_like(x) 

88 >>> _lazyselect([x < 3, x > 3], 

89 ... [lambda x, a: x**2, lambda x, a: a * x**3], 

90 ... (x, a), default=np.nan) 

91 array([ 0., 1., 4., nan, -64., -125.]) 

92 

93 """ 

94 arrays = np.broadcast_arrays(*arrays) 

95 tcode = np.mintypecode([a.dtype.char for a in arrays]) 

96 out = _valarray(np.shape(arrays[0]), value=default, typecode=tcode) 

97 for index in range(len(condlist)): 

98 func, cond = choicelist[index], condlist[index] 

99 if np.all(cond is False): 

100 continue 

101 cond, _ = np.broadcast_arrays(cond, arrays[0]) 

102 temp = tuple(np.extract(cond, arr) for arr in arrays) 

103 np.place(out, cond, func(*temp)) 

104 return out 

105 

106 

107def _aligned_zeros(shape, dtype=float, order="C", align=None): 

108 """Allocate a new ndarray with aligned memory. 

109 

110 Primary use case for this currently is working around a f2py issue 

111 in NumPy 1.9.1, where dtype.alignment is such that np.zeros() does 

112 not necessarily create arrays aligned up to it. 

113 

114 """ 

115 dtype = np.dtype(dtype) 

116 if align is None: 

117 align = dtype.alignment 

118 if not hasattr(shape, '__len__'): 

119 shape = (shape,) 

120 size = functools.reduce(operator.mul, shape) * dtype.itemsize 

121 buf = np.empty(size + align + 1, np.uint8) 

122 offset = buf.__array_interface__['data'][0] % align 

123 if offset != 0: 

124 offset = align - offset 

125 # Note: slices producing 0-size arrays do not necessarily change 

126 # data pointer --- so we use and allocate size+1 

127 buf = buf[offset:offset+size+1][:-1] 

128 data = np.ndarray(shape, dtype, buf, order=order) 

129 data.fill(0) 

130 return data 

131 

132 

133def _prune_array(array): 

134 """Return an array equivalent to the input array. If the input 

135 array is a view of a much larger array, copy its contents to a 

136 newly allocated array. Otherwise, return the input unchanged. 

137 """ 

138 if array.base is not None and array.size < array.base.size // 2: 

139 return array.copy() 

140 return array 

141 

142 

def prod(iterable):
    """
    Multiply together all items of `iterable`, starting from 1.

    Intended for short sequences such as array shapes: it avoids the
    overhead of ``np.prod`` and, with Python integers, cannot overflow.
    An empty iterable yields 1.
    """
    # operator.imul performs ``acc *= item``, matching an explicit loop.
    return functools.reduce(operator.imul, iterable, 1)

154 

155 

class DeprecatedImport(object):
    """
    Deprecated import with redirection and warning.

    Attribute access is forwarded to the new module, emitting a
    ``DeprecationWarning`` that names both the old and the new module.

    Examples
    --------
    Suppose you previously had in some module::

        from foo import spam

    If this has to be deprecated, do::

        spam = DeprecatedImport("foo.spam", "baz")

    to redirect users to use "baz" module instead.

    """

    def __init__(self, old_module_name, new_module_name):
        self._old_name = old_module_name
        self._new_name = new_module_name
        # Import eagerly: a bad `new_module_name` fails here, not on first use.
        __import__(self._new_name)
        self._mod = sys.modules[self._new_name]

    def __dir__(self):
        return dir(self._mod)

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If our own state is
        # missing (instance created without __init__, e.g. by copy/pickle),
        # fail fast instead of recursing through self._old_name forever.
        if name in ('_old_name', '_new_name', '_mod'):
            raise AttributeError(name)
        # stacklevel=2 attributes the warning to the code doing the access.
        warnings.warn("Module %s is deprecated, use %s instead"
                      % (self._old_name, self._new_name),
                      DeprecationWarning, stacklevel=2)
        return getattr(self._mod, name)

189 

# copy-pasted from scikit-learn utils/validation.py
def check_random_state(seed):
    """Turn seed into a np.random.RandomState instance

    If seed is None (or np.random), return the RandomState singleton used
    by np.random.
    If seed is an int, return a new RandomState instance seeded with seed.
    If seed is already a RandomState instance, return it.
    If seed is a new-style np.random.Generator, return it.
    Otherwise, raise ValueError.

    Parameters
    ----------
    seed : {None, int, numpy.random, RandomState, Generator}
        Seed or existing random number generator.

    Returns
    -------
    RandomState or Generator

    Raises
    ------
    ValueError
        If `seed` is none of the accepted kinds (for any type of `seed`,
        including tuples).
    """
    if seed is None or seed is np.random:
        return np.random.mtrand._rand
    if isinstance(seed, (numbers.Integral, np.integer)):
        return np.random.RandomState(seed)
    if isinstance(seed, np.random.RandomState):
        return seed
    try:
        # Generator is only available in numpy >= 1.17
        if isinstance(seed, np.random.Generator):
            return seed
    except AttributeError:
        pass
    # Wrap `seed` in a 1-tuple: a bare tuple seed would otherwise be consumed
    # as the %-format arguments and raise TypeError instead of ValueError.
    raise ValueError('%r cannot be used to seed a numpy.random.RandomState'
                     ' instance' % (seed,))

215 

216 

217def _asarray_validated(a, check_finite=True, 

218 sparse_ok=False, objects_ok=False, mask_ok=False, 

219 as_inexact=False): 

220 """ 

221 Helper function for SciPy argument validation. 

222 

223 Many SciPy linear algebra functions do support arbitrary array-like 

224 input arguments. Examples of commonly unsupported inputs include 

225 matrices containing inf/nan, sparse matrix representations, and 

226 matrices with complicated elements. 

227 

228 Parameters 

229 ---------- 

230 a : array_like 

231 The array-like input. 

232 check_finite : bool, optional 

233 Whether to check that the input matrices contain only finite numbers. 

234 Disabling may give a performance gain, but may result in problems 

235 (crashes, non-termination) if the inputs do contain infinities or NaNs. 

236 Default: True 

237 sparse_ok : bool, optional 

238 True if scipy sparse matrices are allowed. 

239 objects_ok : bool, optional 

240 True if arrays with dype('O') are allowed. 

241 mask_ok : bool, optional 

242 True if masked arrays are allowed. 

243 as_inexact : bool, optional 

244 True to convert the input array to a np.inexact dtype. 

245 

246 Returns 

247 ------- 

248 ret : ndarray 

249 The converted validated array. 

250 

251 """ 

252 if not sparse_ok: 

253 import scipy.sparse 

254 if scipy.sparse.issparse(a): 

255 msg = ('Sparse matrices are not supported by this function. ' 

256 'Perhaps one of the scipy.sparse.linalg functions ' 

257 'would work instead.') 

258 raise ValueError(msg) 

259 if not mask_ok: 

260 if np.ma.isMaskedArray(a): 

261 raise ValueError('masked arrays are not supported') 

262 toarray = np.asarray_chkfinite if check_finite else np.asarray 

263 a = toarray(a) 

264 if not objects_ok: 

265 if a.dtype is np.dtype('O'): 

266 raise ValueError('object arrays are not supported') 

267 if as_inexact: 

268 if not np.issubdtype(a.dtype, np.inexact): 

269 a = toarray(a, dtype=np.float_) 

270 return a 

271 

272 

# A replacement for inspect.getfullargspec(), adapted from Django:
# https://github.com/django/django/pull/4846.
#
# inspect.getfullargspec(func) and inspect.signature(func) disagree on bound
# methods: getfullargspec lists `self` as the first argument, while signature
# omits it. `getfullargspec_no_self` is the common ground -- it returns the
# same record layout as inspect.getfullargspec but is built on
# inspect.signature, so `self` is never listed. Callers therefore do not need
# to know which inspection API is used underneath.

FullArgSpec = namedtuple(
    'FullArgSpec',
    'args varargs varkw defaults kwonlyargs kwonlydefaults annotations',
)

289 

def getfullargspec_no_self(func):
    """inspect.getfullargspec replacement using inspect.signature.

    If func is a bound method, do not list the 'self' parameter.

    Parameters
    ----------
    func : callable
        A callable to inspect

    Returns
    -------
    fullargspec : FullArgSpec(args, varargs, varkw, defaults, kwonlyargs,
                              kwonlydefaults, annotations)

        - ``args``: names of positional-only and positional-or-keyword
          parameters, in order.
        - ``defaults``: tuple of the default values of those parameters
          (aligned with the tail of ``args``), or None if there are none.
        - ``kwonlydefaults``: dict of keyword-only defaults, or None.
        - ``annotations``: parameter annotations only; unlike
          inspect.getfullargspec, the return annotation is not included.

    Notes
    -----
    inspect.signature omits the bound ``self`` of a bound method, so it is
    *not* included in ``fullargspec.args`` (inspect.getfullargspec would
    list it).

    """
    param = inspect.Parameter
    positional_kinds = (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD)

    args = []
    varargs = None
    varkw = None
    defaults = []
    kwonlyargs = []
    kwdefaults = {}
    annotations = {}
    for p in inspect.signature(func).parameters.values():
        if p.kind in positional_kinds:
            args.append(p.name)
            # Positional-only defaults count too; otherwise a caller aligning
            # `defaults` with the tail of `args` would see e.g. `b` in
            # ``def f(a, b=1, /)`` as required.
            if p.default is not p.empty:
                defaults.append(p.default)
        elif p.kind == param.VAR_POSITIONAL:
            varargs = p.name
        elif p.kind == param.KEYWORD_ONLY:
            kwonlyargs.append(p.name)
            if p.default is not p.empty:
                kwdefaults[p.name] = p.default
        elif p.kind == param.VAR_KEYWORD:
            varkw = p.name
        if p.annotation is not p.empty:
            annotations[p.name] = p.annotation

    return FullArgSpec(args, varargs, varkw, tuple(defaults) or None,
                       kwonlyargs, kwdefaults or None, annotations)

343 

344 

class MapWrapper(object):
    """
    Parallelisation wrapper for working with map-like callables, such as
    `multiprocessing.Pool.map`.

    Parameters
    ----------
    pool : int or map-like callable
        If `pool` is an integer, then it specifies the number of worker
        processes (a `multiprocessing.Pool`) to use for parallelization.
        If ``int(pool) == 1``, then no parallel processing is used and the
        map builtin is used.
        If ``pool == -1``, then the pool will utilize all available CPUs.
        If `pool` is a map-like callable that follows the same
        calling sequence as the built-in map function, then this callable is
        used for parallelization.

    Raises
    ------
    RuntimeError
        If `pool` is an integer other than -1 or a value >= 1.

    Notes
    -----
    A `multiprocessing.Pool` created here is owned by the wrapper and is shut
    down by `close`/`terminate`, on leaving a ``with`` block, or when the
    wrapper is garbage collected. A user-supplied callable is never closed.
    """
    def __init__(self, pool=1):
        self.pool = None
        self._mapfunc = map
        self._own_pool = False

        if callable(pool):
            self.pool = pool
            self._mapfunc = self.pool
        else:
            # user supplies a number
            nworkers = int(pool)
            if nworkers == -1:
                # use as many processors as possible
                self.pool = Pool()
                self._mapfunc = self.pool.map
                self._own_pool = True
            elif nworkers == 1:
                pass
            elif nworkers > 1:
                # use the number of processors requested
                self.pool = Pool(processes=nworkers)
                self._mapfunc = self.pool.map
                self._own_pool = True
            else:
                raise RuntimeError("Number of workers specified must be -1,"
                                   " an int >= 1, or an object with a 'map' method")

    def __enter__(self):
        return self

    def __del__(self):
        self.close()
        self.terminate()

    def terminate(self):
        if self._own_pool:
            self.pool.terminate()

    def join(self):
        if self._own_pool:
            self.pool.join()

    def close(self):
        if self._own_pool:
            self.pool.close()

    def __exit__(self, exc_type, exc_value, traceback):
        # Same shutdown sequence as __del__; both are no-ops for a
        # user-supplied callable.
        self.close()
        self.terminate()

    def __call__(self, func, iterable):
        # only accept one iterable because that's all Pool.map accepts
        try:
            return self._mapfunc(func, iterable)
        except TypeError as e:
            # Most likely the map-like callable has the wrong signature, but
            # a TypeError raised by `func` itself also lands here; chain it so
            # the original error stays visible as __cause__.
            raise TypeError("The map-like callable must be of the"
                            " form f(func, iterable)") from e

419 

420 

def _inclusive_bound(bound):
    """Return ``bound + 1`` for an int or array-like bound, never in place."""
    # Scalars keep their own type (a Python int stays arbitrary-precision);
    # array-likes such as plain lists are converted so ``+ 1`` is elementwise.
    if np.ndim(bound) == 0:
        return bound + 1
    return np.asarray(bound) + 1


def rng_integers(gen, low, high=None, size=None, dtype='int64',
                 endpoint=False):
    """
    Return random integers from low (inclusive) to high (exclusive), or if
    endpoint=True, low (inclusive) to high (inclusive). Replaces
    `RandomState.randint` (with endpoint=False) and
    `RandomState.random_integers` (with endpoint=True).

    Return random integers from the "discrete uniform" distribution of the
    specified dtype. If high is None (the default), then results are from
    0 to low.

    Parameters
    ----------
    gen : {None, np.random.RandomState, np.random.Generator}
        Random number generator. If None, then the np.random.RandomState
        singleton is used.
    low : int or array-like of ints
        Lowest (signed) integers to be drawn from the distribution (unless
        high=None, in which case this parameter is 0 and this value is used
        for high).
    high : int or array-like of ints, optional
        If provided, one above the largest (signed) integer to be drawn from
        the distribution (see above for behavior if high=None). If array-like,
        must contain integer values.
    size : int or tuple of ints, optional
        Output shape. If the given shape is, e.g., (m, n, k), then m * n * k
        samples are drawn. Default is None, in which case a single value is
        returned.
    dtype : {str, dtype}, optional
        Desired dtype of the result. All dtypes are determined by their name,
        i.e., 'int64', 'int', etc, so byteorder is not available and a specific
        precision may have different C types depending on the platform.
        The default value is 'int64'.
    endpoint : bool, optional
        If True, sample from the interval [low, high] instead of the default
        [low, high). Defaults to False.

    Returns
    -------
    out : int or ndarray of ints
        size-shaped array of random integers from the appropriate distribution,
        or a single such random int if size not provided.
    """
    # Generator only exists in numpy >= 1.17; isinstance(x, ()) is False.
    if isinstance(gen, getattr(np.random, 'Generator', ())):
        return gen.integers(low, high=high, size=size, dtype=dtype,
                            endpoint=endpoint)

    if gen is None:
        # default is RandomState singleton used by np.random.
        gen = np.random.mtrand._rand

    if endpoint:
        # RandomState.randint has no `endpoint` option, so raise the
        # exclusive upper bound by one (without modifying caller arrays).
        if high is None:
            return gen.randint(_inclusive_bound(low), size=size, dtype=dtype)
        return gen.randint(low, high=_inclusive_bound(high), size=size,
                           dtype=dtype)

    # exclusive
    return gen.randint(low, high=high, size=size, dtype=dtype)