Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import sys 

2import copy 

3import heapq 

4import collections 

5import functools 

6 

7import numpy as np 

8 

9from scipy._lib._util import MapWrapper 

10 

11 

class LRUDict(collections.OrderedDict):
    """Ordered dict bounded to at most ``max_size`` entries.

    Assigning to an existing key moves it to the most-recent end.
    Inserting a new key when the dict is already full evicts the entry at
    the oldest end. Reads do not change an entry's position, because
    ``__getitem__`` is not overridden.
    """

    def __init__(self, max_size):
        self.__max_size = max_size

    def __setitem__(self, key, value):
        was_present = key in self
        super().__setitem__(key, value)
        if was_present:
            # Re-assignment refreshes the entry's recency.
            self.move_to_end(key)
            return
        if len(self) > self.__max_size:
            # Over capacity: drop the oldest entry.
            self.popitem(last=False)

    def update(self, other):
        # Bulk update is not needed by this module and is deliberately
        # unsupported.
        raise NotImplementedError()

27 

28 

class SemiInfiniteFunc(object):
    """
    Argument transform from (start, +-oo) to (0, 1)

    The direction sign ``sgn`` is -1 when ``infty < 0`` and +1 otherwise.
    The substitution ``x = start + sgn*(1 - t)/t`` maps ``t`` in (0, 1]
    onto the half-line. Calling the instance returns the transformed
    integrand ``sgn * f(x) / t**2``.
    """

    def __init__(self, func, start, infty):
        self._func = func
        self._start = start
        if infty < 0:
            self._sgn = -1
        else:
            self._sgn = 1

        # Overflow threshold for the 1/t**2 factor
        self._tmin = sys.float_info.min**0.5

    def get_t(self, x):
        """Map a point ``x`` of the original variable to ``t``."""
        denom = self._sgn * (x - self._start) + 1
        # denom == 0 can happen only if the point is not in range.
        return np.inf if denom == 0 else 1 / denom

    def __call__(self, t):
        # Near t = 0 the integrand is taken to be zero rather than risk
        # overflow in the 1/t**2 factor.
        if t < self._tmin:
            return 0.0
        sgn = self._sgn
        x = self._start + sgn * (1 - t) / t
        return sgn * (self._func(x) / t) / t

55 

56 

class DoubleInfiniteFunc(object):
    """
    Argument transform from (-oo, oo) to (-1, 1)

    The substitution ``x = (1 - |t|)/t`` maps ``t`` in (0, 1] to [0, oo)
    and ``t`` in [-1, 0) to (-oo, 0]. Calling the instance returns the
    transformed integrand ``f(x) / t**2``.
    """

    def __init__(self, func):
        self._func = func

        # Overflow threshold for the 1/t**2 factor
        self._tmin = sys.float_info.min**0.5

    def get_t(self, x):
        """Map a point ``x`` of the original variable to ``t``."""
        sign = 1
        if x < 0:
            sign = -1
        return sign / (abs(x) + 1)

    def __call__(self, t):
        # Near t = 0 the integrand is taken to be zero rather than risk
        # overflow in the 1/t**2 factor.
        if abs(t) < self._tmin:
            return 0.0
        x = (1 - abs(t)) / t
        return (self._func(x) / t) / t

78 

79 

def _max_norm(x):
    """Return the max norm (largest absolute entry) of ``x``.

    ``x`` may be a scalar or an array of any shape. An empty array has
    norm 0, consistent with ``np.linalg.norm``. Plain ``np.amax`` would
    raise on a zero-size reduction. NaN entries still propagate.
    """
    return np.amax(abs(x), initial=0)

82 

83 

def _get_sizeof(obj):
    """Best-effort size of ``obj`` in bytes.

    Uses ``sys.getsizeof``. If that raises TypeError (noted to occur on
    PyPy), it falls back to ``int(obj.__sizeof__())``. If the object has
    no ``__sizeof__`` at all, it assumes 64 bytes.
    """
    try:
        return sys.getsizeof(obj)
    except TypeError:
        # occurs on pypy
        if not hasattr(obj, '__sizeof__'):
            return 64
        return int(obj.__sizeof__())

92 

93 

class _Bunch(object):
    """Simple attribute container.

    Each keyword argument becomes an instance attribute. ``repr`` lists the
    constructor's fields in the order they were passed, using their current
    values.
    """

    def __init__(self, **kwargs):
        # Remember the field names (in call order) for __repr__.
        self.__keys = kwargs.keys()
        self.__dict__.update(**kwargs)

    def __repr__(self):
        attrs = vars(self)
        fields = ", ".join("{}={}".format(name, repr(attrs[name]))
                           for name in self.__keys)
        return "_Bunch({})".format(fields)

102 

103 

def quad_vec(f, a, b, epsabs=1e-200, epsrel=1e-8, norm='2', cache_size=100e6, limit=10000,
             workers=1, points=None, quadrature=None, full_output=False):
    r"""Adaptive integration of a vector-valued function.

    Parameters
    ----------
    f : callable
        Vector-valued function f(x) to integrate.
    a : float
        Initial point.
    b : float
        Final point.
    epsabs : float, optional
        Absolute tolerance.
    epsrel : float, optional
        Relative tolerance.
    norm : {'max', '2'} or callable, optional
        Vector norm to use for error estimation. A callable is used
        directly as the norm.
    cache_size : int, optional
        Number of bytes to use for memoization.
    limit : int, optional
        Maximum number of subintervals. Refinement stops once this many
        subintervals exist. Several intervals may be split in one step, so
        the final count can exceed it somewhat.
    workers : int or map-like callable, optional
        If `workers` is an integer, part of the computation is done in
        parallel subdivided to this many tasks (using
        :class:`python:multiprocessing.pool.Pool`).
        Supply `-1` to use all cores available to the Process.
        Alternatively, supply a map-like callable, such as
        :meth:`python:multiprocessing.pool.Pool.map` for evaluating the
        population in parallel.
        This evaluation is carried out as ``workers(func, iterable)``.
    points : list, optional
        List of additional breakpoints. Points that do not lie strictly
        between `a` and `b` are ignored.
    quadrature : {'gk21', 'gk15', 'trapz'}, optional
        Quadrature rule to use on subintervals.
        Options: 'gk21' (Gauss-Kronrod 21-point rule),
        'gk15' (Gauss-Kronrod 15-point rule),
        'trapz' (composite trapezoid rule).
        Default: 'gk21' for finite intervals and 'gk15' for (semi-)infinite
        intervals.
    full_output : bool, optional
        Return an additional ``info`` dictionary.

    Returns
    -------
    res : {float, array-like}
        Estimate for the result
    err : float
        Error estimate for the result in the given norm
    info : dict
        Returned only when ``full_output=True``.
        Info dictionary. Is an object with the attributes:

            success : bool
                Whether integration reached target precision.
            status : int
                Indicator for convergence, success (0),
                failure (1), failure due to rounding error (2),
                and failure due to non-finite values (3).
            message : str
                Human-readable description of ``status``.
            neval : int
                Number of function evaluations.
            intervals : ndarray, shape (num_intervals, 2)
                Start and end points of subdivision intervals.
                For (semi-)infinite ranges these are expressed in the
                transformed integration variable.
            integrals : ndarray, shape (num_intervals, ...)
                Integral for each interval.
                Note that only as many values as fit in ``cache_size``
                bytes are recorded, and the array may contain *nan* for
                missing items.
            errors : ndarray, shape (num_intervals,)
                Estimated integration error for each interval.

    Notes
    -----
    The algorithm mainly follows the implementation of QUADPACK's
    DQAG* algorithms, implementing global error control and adaptive
    subdivision.

    The algorithm here has some differences to the QUADPACK approach:

    Instead of subdividing one interval at a time, the algorithm
    subdivides N intervals with largest errors at once. This enables
    (partial) parallelization of the integration.

    The logic of subdividing "next largest" intervals first is then
    not implemented, and we rely on the above extension to avoid
    concentrating on "small" intervals only.

    The Wynn epsilon table extrapolation is not used (QUADPACK uses it
    for infinite intervals). This is because the algorithm here is
    supposed to work on vector-valued functions, in a user-specified
    norm, and the extension of the epsilon algorithm to this case does
    not appear to be widely agreed. For max-norm, using elementwise
    Wynn epsilon could be possible, but we do not do this here with
    the hope that the epsilon extrapolation is mainly useful in
    special cases.

    References
    ----------
    [1] R. Piessens, E. de Doncker, QUADPACK (1983).

    Examples
    --------
    We can compute integrations of a vector-valued function:

    >>> from scipy.integrate import quad_vec
    >>> import matplotlib.pyplot as plt
    >>> alpha = np.linspace(0.0, 2.0, num=30)
    >>> f = lambda x: x**alpha
    >>> x0, x1 = 0, 2
    >>> y, err = quad_vec(f, x0, x1)
    >>> plt.plot(alpha, y)
    >>> plt.xlabel(r"$\alpha$")
    >>> plt.ylabel(r"$\int_{0}^{2} x^\alpha dx$")
    >>> plt.show()

    """
    a = float(a)
    b = float(b)

    # Use simple transformations to deal with integrals over infinite
    # intervals; each case recurses into a call over a finite t-range.
    kwargs = dict(epsabs=epsabs,
                  epsrel=epsrel,
                  norm=norm,
                  cache_size=cache_size,
                  limit=limit,
                  workers=workers,
                  points=points,
                  quadrature='gk15' if quadrature is None else quadrature,
                  full_output=full_output)
    if np.isfinite(a) and np.isinf(b):
        f2 = SemiInfiniteFunc(f, start=a, infty=b)
        if points is not None:
            kwargs['points'] = tuple(f2.get_t(xp) for xp in points)
        return quad_vec(f2, 0, 1, **kwargs)
    elif np.isfinite(b) and np.isinf(a):
        # The transform integrates from b towards a, hence the sign flip.
        f2 = SemiInfiniteFunc(f, start=b, infty=a)
        if points is not None:
            kwargs['points'] = tuple(f2.get_t(xp) for xp in points)
        res = quad_vec(f2, 0, 1, **kwargs)
        return (-res[0],) + res[1:]
    elif np.isinf(a) and np.isinf(b):
        sgn = -1 if b < a else 1

        # NB. explicitly split integral at t=0, which separates
        # the positive and negative sides
        f2 = DoubleInfiniteFunc(f)
        if points is not None:
            kwargs['points'] = (0,) + tuple(f2.get_t(xp) for xp in points)
        else:
            kwargs['points'] = (0,)

        if a != b:
            res = quad_vec(f2, -1, 1, **kwargs)
        else:
            # Equal infinite bounds: empty range, integrate over [1, 1].
            res = quad_vec(f2, 1, 1, **kwargs)

        return (res[0]*sgn,) + res[1:]
    elif not (np.isfinite(a) and np.isfinite(b)):
        # Only bounds involving NaN reach this branch.
        raise ValueError("invalid integration bounds a={}, b={}".format(a, b))

    norm_funcs = {
        None: _max_norm,
        'max': _max_norm,
        '2': np.linalg.norm
    }
    if callable(norm):
        norm_func = norm
    else:
        norm_func = norm_funcs[norm]

    parallel_count = 128
    min_intervals = 2

    try:
        _quadrature = {None: _quadrature_gk21,
                       'gk21': _quadrature_gk21,
                       'gk15': _quadrature_gk15,
                       'trapz': _quadrature_trapz}[quadrature]
    except KeyError:
        raise ValueError("unknown quadrature {!r}".format(quadrature))

    # Initial interval set
    if points is None:
        initial_intervals = [(a, b)]
    else:
        # Visit breakpoints in the direction of integration. A reversed
        # range (b < a) is then split at them too, instead of every point
        # failing the range test. Subintervals run from a towards b, so
        # their integrals carry the correct sign.
        lo, hi = min(a, b), max(a, b)
        prev = a
        initial_intervals = []
        for p in sorted(points, reverse=(b < a)):
            p = float(p)
            if not (lo < p < hi) or p == prev:
                continue
            initial_intervals.append((prev, p))
            prev = p
        initial_intervals.append((prev, b))

    global_integral = None
    global_error = None
    rounding_error = None
    interval_cache = None
    intervals = []
    neval = 0

    for x1, x2 in initial_intervals:
        ig, err, rnd = _quadrature(x1, x2, f, norm_func)
        neval += _quadrature.num_eval

        if global_integral is None:
            if isinstance(ig, (float, complex)):
                # Specialize for scalars
                if norm_func in (_max_norm, np.linalg.norm):
                    norm_func = abs

            global_integral = ig
            global_error = float(err)
            rounding_error = float(rnd)

            # Size the LRU cache from the footprint of one result.
            cache_count = cache_size // _get_sizeof(ig)
            interval_cache = LRUDict(cache_count)
        else:
            global_integral += ig
            global_error += err
            rounding_error += rnd

        # Copy: global_integral may alias the first ``ig`` and is updated
        # in place with ``+=`` later on.
        interval_cache[(x1, x2)] = copy.copy(ig)
        intervals.append((-err, x1, x2))

    # Errors are stored negated, so this min-heap pops the largest error.
    heapq.heapify(intervals)

    CONVERGED = 0
    NOT_CONVERGED = 1
    ROUNDING_ERROR = 2
    NOT_A_NUMBER = 3

    status_msg = {
        CONVERGED: "Target precision reached.",
        NOT_CONVERGED: "Target precision not reached.",
        ROUNDING_ERROR: "Target precision could not be reached due to rounding error.",
        NOT_A_NUMBER: "Non-finite values encountered."
    }

    # Create the map wrapper only now. With an integer ``workers`` it may
    # start a process pool, which is released only by the ``with`` block
    # below. All validation and the initial evaluations of ``f`` (which can
    # raise) therefore happen before this point, so the pool cannot leak.
    mapwrapper = MapWrapper(workers)

    # Process intervals
    with mapwrapper:
        ier = NOT_CONVERGED

        while intervals and len(intervals) < limit:
            # Select intervals with largest errors for subdivision
            tol = max(epsabs, epsrel*norm_func(global_integral))

            to_process = []
            err_sum = 0

            for j in range(parallel_count):
                if not intervals:
                    break

                if j > 0 and err_sum > global_error - tol/8:
                    # avoid unnecessary parallel splitting
                    break

                interval = heapq.heappop(intervals)

                neg_old_err, ia, ib = interval
                old_int = interval_cache.pop((ia, ib), None)
                to_process.append(((-neg_old_err, ia, ib, old_int), f, norm_func, _quadrature))
                err_sum += -neg_old_err

            # Subdivide intervals
            for dint, derr, dround_err, subint, dneval in mapwrapper(_subdivide_interval, to_process):
                neval += dneval
                global_integral += dint
                global_error += derr
                rounding_error += dround_err
                for x in subint:
                    x1, x2, ig, err = x
                    interval_cache[(x1, x2)] = ig
                    heapq.heappush(intervals, (-err, x1, x2))

            # Termination check
            if len(intervals) >= min_intervals:
                tol = max(epsabs, epsrel*norm_func(global_integral))
                if global_error < tol/8:
                    ier = CONVERGED
                    break
                if global_error < rounding_error:
                    ier = ROUNDING_ERROR
                    break

            if not (np.isfinite(global_error) and np.isfinite(rounding_error)):
                ier = NOT_A_NUMBER
                break

    res = global_integral
    err = global_error + rounding_error

    if full_output:
        res_arr = np.asarray(res)
        dummy = np.full(res_arr.shape, np.nan, dtype=res_arr.dtype)
        integrals = np.array([interval_cache.get((z[1], z[2]), dummy)
                              for z in intervals], dtype=res_arr.dtype)
        errors = np.array([-z[0] for z in intervals])
        intervals = np.array([[z[1], z[2]] for z in intervals])

        info = _Bunch(neval=neval,
                      success=(ier == CONVERGED),
                      status=ier,
                      message=status_msg[ier],
                      intervals=intervals,
                      integrals=integrals,
                      errors=errors)
        return (res, err, info)
    else:
        return (res, err)

414 

415 

def _subdivide_interval(args):
    """Bisect one interval and report the change in the global estimates.

    ``args`` is ``(interval, f, norm_func, _quadrature)`` where ``interval``
    is ``(old_err, a, b, old_int)``. If ``old_int`` is None, the integral
    over the whole interval is recomputed.

    Returns ``(dint, derr, dround_err, subintervals, dneval)``:
    the change in integral and error estimate, the summed rounding error of
    the two halves, the halves as ``(x1, x2, integral, error)`` tuples, and
    the number of function evaluations performed.
    """
    interval, f, norm_func, _quadrature = args
    old_err, a, b, old_int = interval

    mid = 0.5 * (a + b)

    # Rules that declare a ``cache_size`` get a memoized f, so abscissae
    # shared between the calls below are evaluated only once.
    memoize = getattr(_quadrature, 'cache_size', 0) > 0
    if memoize:
        f = functools.lru_cache(_quadrature.cache_size)(f)

    left, left_err, left_round = _quadrature(a, mid, f, norm_func)
    right, right_err, right_round = _quadrature(mid, b, f, norm_func)
    evals = 2 * _quadrature.num_eval
    if old_int is None:
        old_int, _, _ = _quadrature(a, b, f, norm_func)
        evals += _quadrature.num_eval

    if memoize:
        # With memoization, count only the distinct evaluations performed.
        evals = f.cache_info().misses

    halves = ((a, mid, left, left_err), (mid, b, right, right_err))
    return (left + right - old_int,
            left_err + right_err - old_err,
            left_round + right_round,
            halves,
            evals)

443 

444 

def _quadrature_trapz(x1, x2, f, norm_func):
    """
    Composite trapezoid quadrature

    Evaluates f at x1, x2 and the midpoint, in that order. Returns
    ``(integral, error, rounding_error)``. The integral is the two-panel
    trapezoid value. The error is one third of the norm of its difference
    from the one-panel trapezoid value. The rounding error is a crude bound
    proportional to the norms of the sampled values.
    """
    x_mid = 0.5*(x1 + x2)
    y_lo = f(x1)
    y_hi = f(x2)
    y_mid = f(x_mid)

    width = x2 - x1
    fine = 0.25 * width * (y_lo + 2*y_mid + y_hi)

    weighted_norms = (float(norm_func(y_lo))
                      + 2*float(norm_func(y_mid))
                      + float(norm_func(y_hi)))
    round_err = 0.25 * abs(width) * weighted_norms * 2e-16

    coarse = 0.5 * width * (y_lo + y_hi)
    err = (1/3) * float(norm_func(coarse - fine))
    return fine, err, round_err


# Memo size for _subdivide_interval: 3 points for each of the (up to) 3
# rule applications it makes per interval.
_quadrature_trapz.cache_size = 9
_quadrature_trapz.num_eval = 3

467 

468 

def _quadrature_gk(a, b, f, norm_func, x, w, v):
    """
    Generic Gauss-Kronrod quadrature

    ``x`` are the Kronrod abscissae on [-1, 1] and ``v`` their weights.
    ``w`` are the weights of the embedded Gauss rule, whose abscissae are
    the odd-indexed entries ``x[1], x[3], ...``.

    Returns ``(integral, error, rounding_error)`` for the integral of f
    over [a, b], with the error estimated in the QUADPACK manner.
    """
    center = 0.5 * (a + b)
    half = 0.5 * (b - a)

    # Kronrod pass: evaluate f once per abscissa, keep the values for the
    # Gauss and deviation sums, and accumulate \int f and \int |f|.
    values = []
    kronrod = 0.0
    kronrod_abs = 0.0
    for i, node in enumerate(x):
        value = f(center + half*node)
        values.append(value)
        weight = v[i]
        kronrod += weight * value
        kronrod_abs += weight * abs(value)

    # Embedded Gauss rule reuses the odd-indexed samples.
    gauss = 0.0
    for i, weight in enumerate(w):
        gauss += weight * values[2*i + 1]

    # \int |f(x) - mean|; the Kronrod weights sum to 2 on [-1, 1].
    mean = kronrod / 2.0
    kronrod_dev = 0.0
    for i, value in enumerate(values):
        kronrod_dev += v[i] * abs(value - mean)

    # Use similar error estimation as quadpack
    err = float(norm_func((kronrod - gauss) * half))
    dev = float(norm_func(kronrod_dev * half))
    if dev != 0 and err != 0:
        err = dev * min(1.0, (200 * err / dev)**1.5)

    round_err = float(norm_func(50 * sys.float_info.epsilon * half * kronrod_abs))
    if round_err > sys.float_info.min:
        err = max(err, round_err)

    return half * kronrod, err, round_err

518 

519 

def _quadrature_gk21(a, b, f, norm_func):
    """
    Gauss-Kronrod 21 quadrature with error estimate

    The rule is symmetric about 0, so each table is written as its outer
    half and mirrored. Abscissae are ordered from +1 down to -1, and the
    odd-indexed abscissae are the 10 Gauss points.
    """
    # Positive Gauss-Kronrod points, outermost first
    outer_nodes = (0.995657163025808080735527280689003,
                   0.973906528517171720077964012084452,
                   0.930157491355708226001207180059508,
                   0.865063366688984510732096688423493,
                   0.780817726586416897063717578345042,
                   0.679409568299024406234327365114874,
                   0.562757134668604683339000099272694,
                   0.433395394129247190799265943165784,
                   0.294392862701460198131126603103866,
                   0.148874338981631210884826001129720)
    nodes = outer_nodes + (0,) + tuple(-node for node in reversed(outer_nodes))

    # 10-point Gauss weights, first half
    gauss_half = (0.066671344308688137593568809893332,
                  0.149451349150580593145776339657697,
                  0.219086362515982043995534934228163,
                  0.269266719309996355091226921569469,
                  0.295524224714752870173892994651338)
    gauss_weights = gauss_half + tuple(reversed(gauss_half))

    # 21-point Kronrod weights, outer half then center
    kronrod_half = (0.011694638867371874278064396062192,
                    0.032558162307964727478818972459390,
                    0.054755896574351996031381300244580,
                    0.075039674810919952767043140916190,
                    0.093125454583697605535065465083366,
                    0.109387158802297641899210590325805,
                    0.123491976262065851077958109831074,
                    0.134709217311473325928054001771707,
                    0.142775938577060080797094273138717,
                    0.147739104901338491374841515972068)
    kronrod_center = 0.149445554002916905664936468389821
    kronrod_weights = (kronrod_half + (kronrod_center,)
                       + tuple(reversed(kronrod_half)))

    return _quadrature_gk(a, b, f, norm_func, nodes, gauss_weights, kronrod_weights)


_quadrature_gk21.num_eval = 21

586 

587 

def _quadrature_gk15(a, b, f, norm_func):
    """
    Gauss-Kronrod 15 quadrature with error estimate

    The rule is symmetric about 0, so each table is written as its outer
    half plus center and mirrored. Abscissae are ordered from +1 down to
    -1, and the odd-indexed abscissae are the 7 Gauss points.
    """
    # Positive Gauss-Kronrod points, outermost first
    outer_nodes = (0.991455371120812639206854697526329,
                   0.949107912342758524526189684047851,
                   0.864864423359769072789712788640926,
                   0.741531185599394439863864773280788,
                   0.586087235467691130294144838258730,
                   0.405845151377397166906606412076961,
                   0.207784955007898467600689403773245)
    nodes = (outer_nodes + (0.000000000000000000000000000000000,)
             + tuple(-node for node in reversed(outer_nodes)))

    # 7-point Gauss weights, outer half then center
    gauss_half = (0.129484966168869693270611432679082,
                  0.279705391489276667901467771423780,
                  0.381830050505118944950369775488975)
    gauss_center = 0.417959183673469387755102040816327
    gauss_weights = gauss_half + (gauss_center,) + tuple(reversed(gauss_half))

    # 15-point Kronrod weights, outer half then center
    kronrod_half = (0.022935322010529224963732008058970,
                    0.063092092629978553290700663189204,
                    0.104790010322250183839876322541518,
                    0.140653259715525918745189590510238,
                    0.169004726639267902826583426598550,
                    0.190350578064785409913256402421014,
                    0.204432940075298892414161999234649)
    kronrod_center = 0.209482141084727828012999174891714
    kronrod_weights = (kronrod_half + (kronrod_center,)
                       + tuple(reversed(kronrod_half)))

    return _quadrature_gk(a, b, f, norm_func, nodes, gauss_weights, kronrod_weights)


_quadrature_gk15.num_eval = 15