Coverage for emd/cycles.py: 62%

608 statements  

« prev     ^ index     » next       coverage.py v7.6.11, created at 2025-03-08 15:44 +0000

1#!/usr/bin/python 

2 

3# vim: set expandtab ts=4 sw=4: 

4 

5""" 

6Identification and analysis of cycles in an oscillatory signal. 

7 

8Routines: 

9 get_cycle_vector 

10 get_subset_vector 

11 get_chain_vector 

12 is_good 

13 get_cycle_stat 

14 get_chain_stat 

15 phase_align 

16 normalised_waveform 

17 bin_by_phase 

18 mean_vector 

19 basis_project 

20 get_control_points 

21 get_control_point_metrics 

22 get_control_point_metrics_aug 

23 kdt_match 

24 

25Cycle Features 

26 cf_start_value 

27 cf_end_value 

28 cf_peak_sample 

29 cf_peak_value 

30 cf_trough_sample 

31 cf_trough_value 

32 cf_descending_zero_sample 

33 cf_ascending_zero_sample 

34 

35Classes 

36 Cycles 

37 

38 

39""" 

40 

41import logging 

42import re 

43import warnings 

44from functools import partial 

45 

46import numpy as np 

47 

48from . import _cycles_support, imftools, spectra 

49from ._sift_core import _find_extrema 

50from .support import (ensure_1d_with_singleton, ensure_2d, ensure_equal_dims, 

51 ensure_vector) 

52 

# Module-level logger, registered under this module's dotted import name.
logger = logging.getLogger(name=__name__)

54 

55 

56################################################### 

57# CYCLE IDENTIFICATION 

58 

59 

def get_cycle_inds(*args, **kwargs):
    """Deprecated alias of ``get_cycle_vector``.

    Emits a warning (through both :mod:`warnings` and the module logger) and
    then forwards every positional and keyword argument unchanged to
    ``get_cycle_vector``.

    Returns
    -------
    ndarray
        Whatever ``get_cycle_vector(*args, **kwargs)`` returns.

    """
    msg = "WARNING: 'emd.cycles.get_cycle_inds' is deprecated and " + \
        "will be removed in a future version of EMD. Please change to use " + \
        "'emd.cycles.get_cycle_vector' to remove this warning and " + \
        "future-proof your code"

    # stacklevel=2 attributes the warning to the caller's line instead of this
    # wrapper, so users can locate the call that needs updating.
    warnings.warn(msg, stacklevel=2)
    logger.warning(msg)
    return get_cycle_vector(*args, **kwargs)

70 

71 

def get_cycle_vector(phase, return_good=True, mask=None,
                     imf=None, phase_step=1.5 * np.pi,
                     phase_edge=np.pi / 12, min_len=2):
    """Identify cycles within an instantaneous phase time-course.

    Cycles are identified by large phase jumps and can optionally be tested to
    remove 'bad' cycles by criteria in Notes.

    Parameters
    ----------
    phase : ndarray
        Input array of Instantaneous Phase values. After coercion to 2D, each
        column is processed independently (logged as one IMF per column).
        Phase values above 2pi are wrapped before processing.
    return_good : bool
        Boolean indicating whether 'bad' cycles should be removed (Default value = True)
    mask : ndarray
        Optional boolean vector over samples. Samples where ``mask`` is False
        are excluded: any cycle overlapping such a sample is dropped
        (Default value = None)
    imf : ndarray
        Currently unused by this function; accepted for backward
        compatibility (Default value = None)
    phase_step : scalar
        Minimum value in the differential of the wrapped phase to identify a
        cycle transition (Default value = 1.5*np.pi)
    phase_edge : scalar
        Maximum distance from 0 or 2pi for the first and last phase value in a
        good cycle. Only used when return_good is True
        (Default value = np.pi/12)
    min_len : int
        A good cycle must contain more than ``min_len`` samples. Only used
        when return_good is True (Default value = 2)

    Returns
    -------
    ndarray
        Integer array with the shape of the 2D-coerced phase. Each sample
        holds the zero-based index of the (good) cycle it belongs to, or -1
        if it is not part of any accepted cycle. Cycles are numbered
        separately within each column.

    Notes
    -----
    Good cycles are those with

    * A strictly positively increasing phase
    * A phase starting within phase_edge of zero (ie 0 <= x <= phase_edge)
    * A phase ending within phase_edge of 2pi (ie 2pi-phase_edge <= x <= 2pi)
    * More than min_len samples

    No waveform is passed to ``is_good`` here, so control-point checks are
    not applied.

    Good cycles can be identified with:

    >>> good_cycles = emd.cycles.get_cycle_vector(phase)

    Cycles are numbered from zero, so the total number of cycles is then

    >>> good_cycles.max() + 1

    Indices where good cycles is -1 do not contain a valid cycle

    >>> bad_segments = good_cycles == -1

    A single cycle can be isolated by matching its index, eg for the cycle
    with index 5 (the sixth cycle)

    >>> cycle_5_inds = good_cycles == 5

    """
    # Preamble
    logger.info('STARTED: get cycle indices')
    if mask is not None:
        phase, mask = ensure_2d([phase, mask], ['phase', 'mask'], 'get_cycle_vector')
        ensure_equal_dims((phase, mask), ('phase', 'mask'), 'get_cycle_vector', dim=0)
        # Use a boolean mask so that ``~mask`` is a logical NOT. On an integer
        # mask ``~`` is bitwise (~1 == -2, ~0 == -1), which would mark every
        # sample as excluded.
        mask = np.asarray(mask, dtype=bool)
    else:
        phase = ensure_2d([phase], ['phase'], 'get_cycle_vector')

    logger.debug('computing on {0} samples over {1} IMFs '.format(phase.shape[0],
                                                                phase.shape[1]))
    if mask is not None:
        # Excluded samples are those where the mask is False
        n_masked = np.sum(~mask)
        logger.debug('{0} ({1}%) samples masked out'.format(n_masked, np.round(100*(n_masked/phase.shape[0]), 2)))

    # Main body

    if phase.max() > 2 * np.pi:
        logger.info('Wrapping phase')
        phase = imftools.wrap_phase(phase)

    cycles = np.zeros_like(phase, dtype=int) - 1

    for ii in range(phase.shape[1]):

        # Cycle boundaries are the samples immediately after a large jump in
        # the wrapped phase (the reset from ~2pi back to ~0).
        inds = np.where(np.abs(np.diff(phase[:, ii])) > phase_step)[0] + 1

        # No Cycles to be found
        if len(inds) == 0:
            continue

        # Include first and last cycles,
        # These are likely to be bad/incomplete in real data but we should
        # check anyway
        if inds[0] >= 1:
            # Add first zero, don't if first value is already zero
            inds = np.r_[0, inds]
        if inds[-1] < phase.shape[0] - 1:
            # Add final ind value, don't if final value is already the end of array
            # NOTE(review): cycles are sliced as [start:stop), so the final
            # sample itself is never assigned to a cycle - confirm intended.
            inds = np.r_[inds, phase.shape[0] - 1]

        count = 0
        for start, stop in zip(inds[:-1], inds[1:]):

            # Ignore cycle if a part of it is masked out
            if mask is not None and np.any(~mask[start:stop]):
                continue

            if return_good:
                cycle_checks = is_good(phase[start:stop, ii], ret_all_checks=True,
                                       phase_edge=phase_edge, min_len=min_len)
                if not np.all(cycle_checks):
                    continue

            # Add cycle to list (all checks passed or checks disabled)
            cycles[start:stop, ii] = count
            count += 1

        # ``count`` is the number of accepted cycles in this column
        logger.info('found {0} cycles in IMF-{1}'.format(count, ii))

    logger.info('COMPLETED: get cycle indices')
    return cycles

195 

196 

def get_subset_vector(valids):
    """Number the retained cycles consecutively and mark dropped ones with -1.

    Parameters
    ----------
    valids : boolean ndarray
        One entry per cycle; non-zero (True) entries are kept in the subset.

    Returns
    -------
    ndarray
        Integer vector shaped like ``valids``. Kept cycles receive 0, 1, 2, ...
        in order of appearance; excluded cycles hold -1.

    """
    keep = np.asarray(valids) != 0
    subset_vect = np.full(keep.shape, -1, dtype=int)
    subset_vect[keep] = np.arange(np.count_nonzero(keep))
    return subset_vect

221 

222 

def get_chain_vector(subset_vect):
    """Label runs ("chains") of consecutive retained cycles.

    Parameters
    ----------
    subset_vect : ndarray
        subset vector obtained from emd.cycles.get_subset_vector, where -1
        marks an excluded cycle

    Returns
    -------
    ndarray
        One entry per retained cycle giving the index of the chain it belongs
        to. A new chain begins whenever one or more excluded cycles separate
        two retained cycles.

    """
    kept = np.where(subset_vect > -1)[0]
    # Gap to the previous kept cycle; prepending the first index makes the
    # first gap zero so the first chain is labelled 0.
    gaps = np.diff(kept, prepend=kept[:1])
    # Every gap larger than one starts a new chain.
    return np.cumsum(gaps > 1, dtype=kept.dtype)

250 

251 

def get_cycle_vector_from_waveform(imf, cycle_start='peaks'):
    """Compute cycle locations from time domain IMF.

    THIS IS A WORK-IN-PROGRESS WHICH IS ASSUMING LOCALLY SYMMETRICAL SIGNALS!!

    Parameters
    ----------
    imf : ndarray
        Time-domain IMF(s); coerced to [samples x IMFs] before processing.
    cycle_start : {'peaks', 'troughs', 'asc'}
        Feature marking the start of each cycle: successive peaks, successive
        troughs, or the ascending zero-crossing preceding successive peaks.
        'desc' is reserved but not implemented yet.

    Returns
    -------
    ndarray
        Integer array shaped like the coerced ``imf``. Samples inside the
        jj-th cycle found in a column hold jj+1; all other samples hold 0.

    Raises
    ------
    ValueError
        If ``cycle_start`` is 'desc' (not implemented) or unrecognised.

    """
    imf = ensure_1d_with_singleton([imf], ['imf'], 'get_cycle_vector_from_waveform')

    if cycle_start == 'desc':
        raise ValueError("cycle_start='desc' is not implemented yet")
    if cycle_start not in ('peaks', 'troughs', 'asc'):
        raise ValueError("cycle_start must be one of 'peaks', 'troughs' or 'asc', "
                         "got {0!r}".format(cycle_start))

    cycles = np.zeros_like(imf)
    for ii in range(imf.shape[1]):
        peak_loc, _ = _find_extrema(imf[:, ii])
        trough_loc, _ = _find_extrema(-imf[:, ii])

        if cycle_start == 'peaks':
            # Each cycle runs from one peak up to (not including) the next
            for jj in range(len(peak_loc) - 1):
                cycles[peak_loc[jj]:peak_loc[jj + 1], ii] = jj + 1
        elif cycle_start == 'troughs':
            # Iterate over troughs so that trough_loc[jj + 1] always exists
            for jj in range(len(trough_loc) - 1):
                cycles[trough_loc[jj]:trough_loc[jj + 1], ii] = jj + 1
        else:  # 'asc'
            for jj in range(len(peak_loc) - 1):
                start = _ascending_zero_before_peak(imf[:, ii], peak_loc[jj], trough_loc)
                stop = _ascending_zero_before_peak(imf[:, ii], peak_loc[jj + 1], trough_loc)
                if start is None or stop is None:
                    # No usable ascending zero-crossing around one of the peaks
                    continue
                cycles[start:stop, ii] = jj + 1

    return cycles.astype(int)


def _ascending_zero_before_peak(x, pk, trough_loc):
    """Return the first ascending zero-crossing between ``pk`` and its preceding trough, or None."""
    before = np.where(trough_loc - pk < 0)[0]
    if len(before) == 0:
        # No trough precedes this peak
        return None
    tr = trough_loc[before[-1]]
    if (x[tr] > 0) or (x[pk] < 0):
        # Trough above / peak below zero: no crossing between them
        return None
    # A sign step of +2 marks a strict negative-to-positive transition
    crossings = np.where(np.diff(np.sign(x[tr:pk])) == 2)[0]
    if len(crossings) == 0:
        return None
    return crossings[0] + tr

297 

298 

def is_good(phase, waveform=None, ret_all_checks=False, phase_edge=np.pi/12, mode='cycle', min_len=2):
    """Run a set of phase checks to check if a cycle is 'good' or 'bad'.

    This implements the checks defined in [1]_ and one additional length
    check. Cycles meeting these criteria have a good chance of providing an
    interpretable instantaneous frequency estimate.

    Parameters
    ----------
    phase : ndarray
        Instantaneous Phase of the cycle to be checked
    waveform : ndarray
        Optional time-domain waveform to enable control point checks
    ret_all_checks : bool
        Boolean flag indicating whether check results are returned separately
        (Default value = False)
    phase_edge : scalar
        Maximum distance of the first phase value above the phase origin, and
        of the last phase value below 2pi, in a good cycle
        (Default value = np.pi/12)
    mode : {'cycle', 'augmented'}
        In 'augmented' mode the phase is unwrapped and shifted down by 2pi,
        and the phase origin for the start check is -pi/2 instead of 0. Any
        other value behaves like 'cycle' (Default value = 'cycle')
    min_len : int
        The cycle must contain more than ``min_len`` samples (Default value = 2)

    Returns
    -------
    bool or ndarray
        Flag indicating whether cycle is good, or, if ``ret_all_checks`` is
        True, a length-5 boolean array with one entry per check in the order
        listed in Notes.

    Notes
    -----
    A good cycle is defined as:

    * having a phase with a strictly positive differential (i.e., no phase reversals)
    * starting with a phase value phase_min ≤ x ≤ phase_min + phase_edge,
      where phase_min is 0 ('cycle') or -π/2 ('augmented')
    * ending within 2π − phase_edge ≤ x ≤ 2π
    * having distinct consecutive control points (start, peak, descending
      edge, trough, end) when ``waveform`` is given; passes automatically otherwise
    * exceeding a minimum length in samples

    References
    ----------
    .. [1] Andrew J. Quinn, Vitor Lopes-dos-Santos, Norden Huang, Wei-Kuang
       Liang, Chi-Hung Juan, Jia-Rong Yeh, Anna C. Nobre, David Dupret, & Mark W.
       Woolrich (2021). Within-cycle instantaneous frequency profiles report
       oscillatory waveform dynamics. bioRxiv, 2021.04.12.439547.
       https://doi.org/10.1101/2021.04.12.439547

    """
    cycle_checks = np.zeros((5,), dtype=bool)

    if mode == 'augmented':
        phase = np.unwrap(phase) - 2*np.pi
        phase_min = -np.pi/2
    else:
        phase_min = 0

    # Check for positively increasing phase
    if np.all(np.diff(phase) > 0):
        cycle_checks[0] = True

    # Check that start of cycle is close to the phase origin
    if phase_min <= phase[0] <= phase_min + phase_edge:
        cycle_checks[1] = True

    # Check that end of cycle is close to 2pi
    if 2 * np.pi - phase_edge <= phase[-1] <= 2 * np.pi:
        cycle_checks[2] = True

    if waveform is not None:
        # Check we find 5 sensible control points if a waveform is provided:
        # start, first peak, first descending sign change, first trough, end
        try:
            # Should extend this to cope with multiple peaks etc
            ctrl = (0, _find_extrema(waveform)[0][0],
                    np.where(np.gradient(np.sign(waveform)) == -1)[0][0],
                    _find_extrema(-waveform)[0][0],
                    len(waveform))
        except IndexError:
            # Sometimes we don't find any candidate for a control point
            cycle_checks[3] = False
        else:
            # Consecutive control points must differ; their order is not enforced
            if np.all(np.sign(np.diff(ctrl))):
                cycle_checks[3] = True
    else:
        # No time-series so assume everything is fine
        cycle_checks[3] = True

    # Check we exceed a minimum length
    if len(phase) > min_len:
        cycle_checks[4] = True

    if ret_all_checks:
        return cycle_checks
    else:
        return np.all(cycle_checks)

391 

392 

393################################################### 

394# CYCLE COMPUTATION 

395 

def get_cycle_stat(cycles, values, mode='cycle', out=None, func=np.mean):
    """
    Compute a summary statistic of a set of observations for each cycle.

    Parameters
    ----------
    cycles : ndarray
        array whose content index cycle locations (any input accepted by the
        module's cycle-input coercion)
    values : ndarray
        array of observations to summarise within each cycle; its length must
        equal the number of samples described by ``cycles``
    mode : {'cycle', 'augmented'}
        Whether to summarise standard cycles or augmented cycles; the
        augmented computation additionally uses the cycles' phase
        (Default value = 'cycle')
    out : {None, 'samples'}
        If 'samples', the per-cycle values are projected back onto a
        per-sample vector; otherwise one value per cycle is returned
        (Default value = None)
    func : function
        Function to call on the data in values for each cycle (Default
        np.mean). This can be any function, built-in or user defined, that
        processes a single vector of data returning a single value.

    Returns
    -------
    ndarray
        Array containing the per-cycle (or per-sample) values

    Raises
    ------
    ValueError
        If ``mode`` is not recognised, or ``cycles`` and ``values`` describe
        different numbers of samples.

    """
    # Preamble
    logger.info('STARTED: get_cycle_stat')

    # Validate mode before touching ``cycles`` so bad input fails early
    if mode not in ('cycle', 'augmented'):
        raise ValueError("mode must be 'cycle' or 'augmented', got {0!r}".format(mode))

    values = ensure_vector([values], ['values'], 'get_cycle_stat')

    cycles = _ensure_cycle_inputs(cycles)
    cycles.mode = mode

    if cycles.nsamples != values.shape[0]:
        raise ValueError("Mismatched inputs between 'cycles' and 'values'")

    # Main Body

    if mode == 'cycle':
        vals = _cycles_support.get_cycle_stat_from_samples(values, cycles.cycle_vect, func=func)
    else:
        vals = _cycles_support.get_augmented_cycle_stat_from_samples(values, cycles.cycle_vect,
                                                                     cycles.phase, func=func)

    if out == 'samples':
        vals = _cycles_support.project_cycles_to_samples(vals, cycles.cycle_vect)

    logger.info('COMPLETED: get_cycle_stat')
    return vals

446 

447 

def get_chain_stat(chains, var, func=np.mean):
    """
    Compute a given function for observations across each chain of cycles.

    Parameters
    ----------
    chains : list
        Nested list of cycle indices, one inner sequence per chain, each
        indexing into ``var`` (e.g. the output of emd.cycles.get_cycle_chain).
    var : ndarray
        1d array properties across all good cycles. Compressed output
        of emd.cycles.get_cycle_stat
    func : function
        Function to call on the data in values for each cycle (Default
        np.mean). This can be any function, built-in or user defined, that
        processes a single vector of data returning a single value.

    Returns
    -------
    stat : ndarray
        1D array of evaluated function on property var across each chain.

    """
    # Preamble
    logger.info('STARTED: get chain stat')

    # Accept plain sequences as well; fancy indexing below needs an ndarray
    var = np.asarray(var)

    logger.debug('computing stats for {0} cycles over {1} chains'.format(len(var), len(chains)))
    logger.debug('computing metric {0}'.format(func))

    # Actual computation: apply ``func`` to the values of each chain's cycles
    stat = np.array([func(var[x]) for x in chains])

    logger.info('COMPLETED: get chain stat')
    return stat

481 

482 

def phase_align(ip, x, cycles=None, npoints=48, interp_kind='linear', min_len=1, mode='cycle'):
    """Align a vector of observations to a template phase time-course.

    This implements the phase alignment method introduced in [1]_. Individual
    cycles must contain more than ``min_len`` samples to be phase-aligned; if
    a cycle cannot be phase aligned its output will be set to np.nan.

    Parameters
    ----------
    ip : ndarray
        Input array of Instantaneous Phase values to base alignment on
    x : ndarray
        Input array of observed values to phase align
    cycles : ndarray (optional)
        Optional set of cycles within IP to use. If None, cycles are found
        with ``get_cycle_vector(ip, return_good=False)`` (Default value = None)
    npoints : int
        Number of points in the phase cycle to align to (Default = 48)
    interp_kind : {'linear','nearest','zero','slinear', 'quadratic','cubic','previous', 'next'}
        Type of interpolation to perform. Argument is passed onto
        scipy.interpolate.interp1d. (Default = 'linear')
    min_len : int
        Cycles with ``min_len`` samples or fewer are not aligned and are
        returned as nans (Default = 1)
    mode : {'cycle', 'augmented'}
        Whether to phase align a standard 'cycle' (phase 0 to 2pi) or an
        'augmented' cycle including a 5th quadrant (phase -pi/2 to 2pi).

    Returns
    -------
    ndarray :
        [npoints x cycles] array containing the phase aligned observations
    ndarray :
        The phase values (bin centres) the observations were aligned to

    Raises
    ------
    ValueError
        If ``mode`` is not recognised or ``cycles`` and ``ip`` are mismatched.

    References
    ----------
    .. [1] Andrew J. Quinn, Vitor Lopes-dos-Santos, Norden Huang, Wei-Kuang
       Liang, Chi-Hung Juan, Jia-Rong Yeh, Anna C. Nobre, David Dupret, & Mark W.
       Woolrich (2021). Within-cycle instantaneous frequency profiles report
       oscillatory waveform dynamics. bioRxiv, 2021.04.12.439547.
       https://doi.org/10.1101/2021.04.12.439547

    """
    # Preamble
    from scipy import interpolate as interp
    logger.info('STARTED: phase-align cycles')

    # Without this check an unknown mode left ``phase_bins`` undefined
    if mode not in ('cycle', 'augmented'):
        raise ValueError("mode must be 'cycle' or 'augmented', got {0!r}".format(mode))

    ip, x = ensure_vector((ip, x), ('ip', 'x'), 'phase_align')
    ensure_equal_dims((ip, x), ('ip', 'x'), 'phase_align')

    if cycles is None:
        cycles = get_cycle_vector(ip, return_good=False)
    cycles = _ensure_cycle_inputs(cycles)

    cycles.mode = mode

    if cycles.nsamples != ip.shape[0]:
        raise ValueError("Mismatched inputs between 'cycles' and 'ip'")

    # Main Body

    # Augmented cycles start a quarter cycle early, at -pi/2
    phase_start = 0 if mode == 'cycle' else -np.pi / 2
    phase_edges, phase_bins = spectra.define_hist_bins(phase_start, 2 * np.pi, npoints)

    msg = 'aligning {0} cycles over {1} phase points with {2} interpolation'
    logger.debug(msg.format(cycles.niters, npoints, interp_kind))

    avg = np.full((npoints, cycles.niters), np.nan)
    for cind, cycle_inds in cycles:
        # Too-short (or absent) cycles are left as NaN
        if (cycle_inds is None) or (len(cycle_inds) <= min_len):
            continue
        phase_data = ip[cycle_inds].copy()

        if mode == 'augmented':
            # Remove the 2pi wrap inside the cycle and shift down by 2pi to
            # match the augmented phase range used for the bins
            phase_data = np.unwrap(phase_data) - 2 * np.pi

        x_data = x[cycle_inds]

        f = interp.interp1d(phase_data, x_data, kind=interp_kind,
                            bounds_error=False, fill_value='extrapolate')

        avg[:, cind] = f(phase_bins)

    logger.debug('{0} cycles not aligned'.format(np.isnan(avg[0, :]).sum()))

    logger.info('COMPLETED: phase-align cycles')
    return avg, phase_bins

571 

572 

def normalised_waveform(infreq):
    """Compute the time-domain waveform of a phase-aligned IF profile.

    Each column of ``infreq`` is rescaled so its phase increments add up to
    one full cycle (2pi); the cumulative phase, starting at zero, is then
    passed through a sine.

    Parameters
    ----------
    infreq : ndarray
        instantaneous frequency profiles [samples x cycles] such as the output
        from emd.cycles.phase_align.

    Returns
    -------
    ndarray
        The normalised waveforms of the cycles in infreq, [samples+1 x cycles]
    ndarray
        A reference sinusoid over one cycle with samples+1 points

    References
    ----------
    .. [1] Andrew J. Quinn, Vitor Lopes-dos-Santos, Norden Huang, Wei-Kuang
       Liang, Chi-Hung Juan, Jia-Rong Yeh, Anna C. Nobre, David Dupret, & Mark W.
       Woolrich (2021). Within-cycle instantaneous frequency profiles report
       oscillatory waveform dynamics. bioRxiv, 2021.04.12.439547.
       https://doi.org/10.1101/2021.04.12.439547

    """
    infreq = ensure_2d([infreq], ['infreq'], 'normalised_waveform')
    nsamples, ncycles = infreq.shape[0], infreq.shape[1]

    nw = np.zeros((nsamples + 1, ncycles))
    for ii in range(ncycles):
        # mean * length == sum, so the scaled increments total 2pi
        sr = infreq[:, ii].mean() * nsamples
        phase_diff = (infreq[:, ii] / sr) * (2 * np.pi)
        phase = np.r_[0, np.cumsum(phase_diff, axis=0)]
        nw[:, ii] = np.sin(phase)

    # Computed outside the loop so it is defined even when there are no
    # cycles (previously this raised NameError for zero columns)
    sine = np.sin(np.linspace(0, 2*np.pi, nsamples + 1))

    return nw, sine

609 

610 

def bin_by_phase(ip, x, nbins=24, weights=None, variance_metric='variance',
                 bin_edges=None):
    """Compute distribution of x by phase-bins in the Instantaneous Phase.

    Parameters
    ----------
    ip : ndarray
        Input vector of instantaneous phase values
    x : ndarray
        Input array of values to be binned, first dimension must match length of
        IP
    nbins : integer
        number of phase bins to define (Default value = 24)
    weights : ndarray (optional)
        Optional set of linear weights, one per sample, to apply before
        averaging (Default value = None)
    variance_metric : {'variance','std','sem'}
        Flag to select whether the variance, standard deviation or standard
        error of the mean in computed across cycles (Default value = 'variance')
    bin_edges : ndarray (optional)
        Optional set of bin edges to override automatic bin specification (Default value = None)

    Returns
    -------
    avg : ndarray
        Array containing the (weighted) average in each phase bin
    var : ndarray
        Array containing the selected variance metric in each phase bin
    bin_centres : ndarray
        Vector of bin centres

    Raises
    ------
    ValueError
        If ``variance_metric`` is not one of the supported options.

    Notes
    -----
    Bins containing no samples are returned as NaN. Samples whose phase lies
    outside the bin edges are ignored. The 'sem' metric divides by the
    square root of the (unweighted) number of samples in the bin.

    """
    if variance_metric not in ('variance', 'std', 'sem'):
        raise ValueError("variance_metric must be 'variance', 'std' or 'sem', "
                         "got {0!r}".format(variance_metric))

    # Preamble
    ip = ensure_vector([ip], ['ip'], 'bin_by_phase')
    if weights is not None:
        weights = ensure_1d_with_singleton([weights], ['weights'], 'bin_by_phase')
        ensure_equal_dims((ip, x, weights), ('ip', 'x', 'weights'), 'bin_by_phase', dim=0)
    else:
        ensure_equal_dims((ip, x), ('ip', 'x'), 'bin_by_phase', dim=0)

    # Main body

    if bin_edges is None:
        bin_edges, bin_centres = spectra.define_hist_bins(0, 2 * np.pi, nbins)
    else:
        nbins = len(bin_edges) - 1
        bin_centres = bin_edges[:-1] + np.diff(bin_edges) / 2

    # np.digitize labels samples in bin k (edges[k-1] <= ip < edges[k]) with
    # k = 1..nbins; 0 and nbins+1 are out of range.
    bin_inds = np.digitize(ip, bin_edges)

    out_dims = (nbins, *x.shape[1:])
    avg = np.full(out_dims, np.nan)
    var = np.full(out_dims, np.nan)
    # Loop over every bin including the last (label nbins)
    for ii in range(nbins):
        inds = bin_inds == ii + 1
        count = np.count_nonzero(inds)
        if count == 0:
            # Empty bin: leave as NaN
            continue
        xb = x[inds, ...]

        if weights is None:
            w = None
        else:
            # Broadcast the per-sample weights across any trailing dimensions of x
            w = weights[inds].reshape((-1,) + (1,) * (x.ndim - 1))
            w = np.broadcast_to(w, xb.shape)

        avg[ii, ...] = np.average(xb, axis=0, weights=w)
        # Square the deviation from the bin mean (the weighted branch used to
        # square only the mean)
        v = np.average((xb - avg[ii, ...]) ** 2, axis=0, weights=w)

        if variance_metric == 'variance':
            var[ii, ...] = v
        elif variance_metric == 'std':
            var[ii, ...] = np.sqrt(v)
        else:  # 'sem'
            var[ii, ...] = np.sqrt(v) / np.sqrt(count)

    return avg, var, bin_centres

688 

689 

def mean_vector(IP, X, mask=None):
    """Average observations after rotating them by their phase's unit vector.

    Each row of ``X`` is multiplied by the complex unit vector
    ``cos(IP) + 1j*sin(IP)`` of its phase, and the products are averaged over
    the first axis.

    Parameters
    ----------
    IP : ndarray
        Instantaneous Phase values, one per row of ``X``
    X : ndarray
        Observations corresponding to IP values
    mask :
        Currently unused (Default value = None)

    Returns
    -------
    mv : ndarray
        Complex mean vector(s), averaged over the first axis

    """
    unit_vectors = np.cos(IP) + 1j * np.sin(IP)
    rotated = unit_vectors[:, np.newaxis] * X
    return np.mean(rotated, axis=0)

711 

712 

def basis_project(X, ncomps=1, ret_basis=False):
    """Express a set of signals in a simple sine-cosine basis set.

    The basis holds one cosine/sine pair per harmonic k = 1..ncomps, each
    sampled over ``linspace(0, 2*k*pi, nsamples)``, in the row order
    cos1, sin1, cos2, sin2, ...

    Parameters
    ----------
    X : ndarray
        Observations to project; the first dimension is the samples axis
    ncomps : int
        Number of sine-cosine pairs to express signal in (default=1). Values
        below 1 are treated as 1.
    ret_basis : bool
        Flag indicating whether to return basis set (default=False)

    Returns
    -------
    ndarray
        ``basis.dot(X)``: the values in basis dimensions, 2*ncomps rows
    ndarray
        The [2*ncomps x nsamples] basis, only if ``ret_basis`` is True

    """
    nsamples = X.shape[0]

    # Previously range(1, ncomps + 1) was appended on top of the fundamental,
    # producing ncomps + 1 pairs for ncomps > 1.
    rows = []
    for freq in range(1, max(ncomps, 1) + 1):
        t = np.linspace(0, 2 * freq * np.pi, nsamples)
        rows.append(np.cos(t))
        rows.append(np.sin(t))
    basis = np.array(rows)

    if ret_basis:
        return basis.dot(X), basis
    else:
        return basis.dot(X)

748 

749 

750################################################### 

751# CONTROL POINT FEATURES 

752 

753 

def get_control_points(x, cycles, interp=False, mode='cycle'):
    """Identify sets of control points from identified cycles.

    The control points are the ascending zero, peak, descending zero & trough.

    Parameters
    ----------
    x : ndarray
        Input array of oscillatory data
    cycles : ndarray
        array whose content index cycle locations. A plain ndarray is not
        accepted when ``mode='augmented'``.
    interp : bool
        Passed to the cf_* feature functions to request interpolated
        (sub-sample) locations (Default value = False)
    mode : {'cycle', 'augmented'}
        Which cycle definition to use (Default value = 'cycle')

    Returns
    -------
    ndarray
        The control points for each cycle in x, one row per cycle, as sample
        offsets from the cycle start. Rows are (start, peak, desc, trough, end)
        in 'cycle' mode and (start, asc, peak, desc, trough, end) in
        'augmented' mode. Points that cannot be found, and all points of
        cycles shorter than 5 samples, are NaN.

    Raises
    ------
    ValueError
        If ``mode`` is unrecognised, if an ndarray is used with augmented
        mode, or if ``cycles`` and ``x`` are mismatched.

    """
    if mode not in ('cycle', 'augmented'):
        raise ValueError("mode must be 'cycle' or 'augmented', got {0!r}".format(mode))
    if isinstance(cycles, np.ndarray) and mode == 'augmented':
        raise ValueError("mode='augmented' is not supported when cycles is passed as an ndarray")

    # Preamble
    x = ensure_vector([x], ['x'], 'get_control_points')
    cycles = _ensure_cycle_inputs(cycles)
    if mode == 'augmented':
        cycles.mode = 'augmented'

    if cycles.nsamples != x.shape[0]:
        raise ValueError("Mismatched inputs between 'cycles' and 'values'")

    # Main Body

    npoints = 6 if mode == 'augmented' else 5
    ctrl = list()
    for cind, cycle_inds in cycles:
        if (cycle_inds is None) or (len(cycle_inds) < 5):
            # We need at least 5 samples to compute control points...
            ctrl.append((np.nan,) * npoints)
            continue

        cycle = x[cycle_inds]

        if mode == 'augmented':
            asc = cf_ascending_zero_sample(cycle, interp=interp)

        pk = cf_peak_sample(cycle, interp=interp)
        desc = cf_descending_zero_sample(cycle, interp=interp)
        tr = cf_trough_sample(cycle, interp=interp)

        # Append to list
        if mode == 'augmented':
            ctrl.append((0, asc, pk, desc, tr, len(cycle)-1))
        else:
            ctrl.append((0, pk, desc, tr, len(cycle)-1))

    # Return as array
    ctrl = np.array(ctrl)
    if np.any(ctrl == None):  # noqa: E711
        # Feature functions return None when a point is not found, which makes
        # the array object-dtype; swap in NaN and cast back to float.
        ctrl = np.where(ctrl == None, np.nan, ctrl).astype(float)  # noqa: E711

    return ctrl

821 

822 

def get_control_point_metrics(ctrl, normalise=True):
    """Compute waveform-shape asymmetry measures from control points.

    Parameters
    ----------
    ctrl : ndarray
        Control points, one row per cycle, with columns
        (start, peak, descending zero, trough, end) given as sample offsets
        from the cycle start.
    normalise : bool
        If True, divide both measures by the end offset in column 4
        (Default value = True)

    Returns
    -------
    p2t : ndarray
        Peak-minus-trough duration: the descending-zero offset minus the
        time from the descending zero to the end of the cycle.
    a2d : ndarray
        Ascending-minus-descending duration: (peak offset + time from trough
        to end) minus the time from peak to trough.

    """
    peak = ctrl[:, 1]
    desc_zero = ctrl[:, 2]
    trough = ctrl[:, 3]
    end = ctrl[:, 4]

    peak_half = desc_zero
    trough_half = end - desc_zero
    p2t = peak_half - trough_half

    ascending = peak + (end - trough)
    descending = trough - peak
    a2d = ascending - descending

    if normalise:
        p2t = p2t / end
        a2d = a2d / end

    return p2t, a2d

835 

836 

def get_control_point_metrics_aug(ctrl):
    """Compute shape ratios from augmented cycle control points.

    Parameters
    ----------
    ctrl : ndarray
        One row per cycle with columns (start, asc, peak, desc, trough, end)

    Returns
    -------
    p2t : ndarray
        Peak-to-trough ratio ( P / P+T ): the span from ascending to
        descending zero divided by the span from ascending zero to the end.
    a2d : ndarray
        Ascending-to-descending ratio ( A / A+D ): peak offset divided by
        trough offset.

    """
    asc_zero = ctrl[:, 1]
    peak = ctrl[:, 2]
    desc_zero = ctrl[:, 3]
    trough = ctrl[:, 4]
    end = ctrl[:, 5]

    p2t = (desc_zero - asc_zero) / (end - asc_zero)
    a2d = peak / trough

    return p2t, a2d

849 

850 

851################################################### 

852# FEATURE MATCHING 

853 

854 

def kdt_match(x, y, K=15, distance_upper_bound=np.inf):
    """Find unique nearest-neighbours between two n-dimensional feature sets.

    Useful for matching two sets of cycles on one or more features (ie
    amplitude and average frequency).

    Rows in x are matched to rows in y. As such - it is good to have (many)
    more rows in y than x if possible.

    This uses a k-dimensional tree to query for the K nearest neighbours and
    returns the closest unique neighbour. If no unique match is found - the row
    is not returned. Increasing K will find more matches but allow matches
    between more distant observations.

    Not advisable for use with more than a handful of features.

    Parameters
    ----------
    x : ndarray
        [ num observations x num features ] array to match to
    y : ndarray
        [ num observations x num features ] array of potential matches
    K : int
        number of potential nearest-neigbours to query
    distance_upper_bound : float
        Neighbours further away than this are ignored; passed to
        scipy.spatial.cKDTree.query (Default value = np.inf)

    Returns
    -------
    ndarray
        indices of matched observations in x
    ndarray
        indices of matched observations in y

    """
    x = np.asarray(x)
    y = np.asarray(y)
    if x.ndim == 1:
        x = x[:, None]
    if y.ndim == 1:
        y = y[:, None]

    #
    logger.info('Starting KD-Tree Match')
    msg = 'Matching {0} features from y ({1} observations) to x ({2} observations)'
    logger.info(msg.format(x.shape[1], y.shape[0], x.shape[0]))
    logger.debug('K: {0}, distance_upper_bound: {1}'.format(K, distance_upper_bound))

    # Initialise Tree and find nearest neighbours
    from scipy import spatial
    kdt = spatial.cKDTree(y)
    D, inds = kdt.query(x, k=K, distance_upper_bound=distance_upper_bound)
    # query() drops the neighbour axis when K == 1; restore it so that
    # columns can be indexed uniformly below.
    D = np.reshape(D, (x.shape[0], K))
    inds = np.reshape(inds, (x.shape[0], K))

    # cKDTree reports neighbours beyond distance_upper_bound with index y.shape[0]
    missing = y.shape[0]

    II = np.zeros_like(inds)
    selected = set()
    for ii in range(K):
        # Find unique values and their indices in this column
        uni, uni_inds = _unique_inds(inds[:, ii])
        # Get index of lowest distance match amongst occurrences of each unique value
        ix = [np.argmin(D[uni_inds[jj], ii]) for jj in range(len(uni))]
        # Map closest match index to full column index
        closest_uni_inds = [uni_inds[jj][ix[jj]] for jj in range(len(uni))]
        # Remove the "missing neighbour" index (distance above threshold)
        uni = uni[uni != missing]
        # Remove previously selected
        not_taken = np.array([u not in selected for u in uni], dtype=bool)
        uni = uni[not_taken]
        # Find indices of matches between uniques and values in col
        uni_matches = np.zeros((inds.shape[0],))
        uni_matches[closest_uni_inds] = np.sum(inds[closest_uni_inds, ii, None] == uni, axis=1)
        # Remove matches which are selected in previous columns
        uni_matches[II[:, :ii].sum(axis=1) > 0] = 0
        # Mark remaining matches with 1s in this col
        matched_rows = np.where(uni_matches)[0]
        II[matched_rows, ii] = 1
        selected.update(inds[matched_rows, ii].tolist())

        msg = '{0} Matches in layer {1}'
        logger.debug(msg.format(np.sum(uni_matches), ii))

    # Find column index of left-most choice per row (ie closest unique neighbour)
    winner = np.argmax(II, axis=1)
    # Find row index of winner; -1 marks rows without a valid match
    final = np.full((II.shape[0],), -1, dtype=int)
    for ii in range(II.shape[0]):
        if (np.sum(II[ii, :]) == 1) and (inds[ii, winner[ii]] < missing):
            final[ii] = inds[ii, winner[ii]]

    # Remove failed matches
    x_inds = np.where(final > -1)[0]
    y_inds = final[x_inds]

    #
    logger.info('Returning {0} matched observations'.format(x_inds.shape[0]))

    return x_inds, y_inds

951 

952 

def _unique_inds(ar):
    """Find the unique elements of an array and every position of each.

    The input is flattened first, so shape is ignored. Unlike
    ``numpy.unique(..., return_index=True)``, which only reports the first
    occurrence, every occurrence of each unique value is returned.

    Parameters
    ----------
    ar : array_like
        Values to search.

    Returns
    -------
    uniques : ndarray
        Sorted unique values of ``ar``.
    ar_inds : list of ndarray
        ``ar_inds[k]`` holds the (ascending) positions in the flattened
        ``ar`` where ``ar == uniques[k]``.

    """
    ar = np.asanyarray(ar).ravel()
    uniques, inverse = np.unique(ar, return_inverse=True)
    if uniques.size == 0:
        return uniques, []

    # Positions must index the ORIGINAL (unsorted) array, because callers use
    # them as row indices. Group positions by their unique label: a stable
    # argsort keeps positions ascending within each group.
    inverse = np.asarray(inverse).ravel()
    order = np.argsort(inverse, kind='stable')
    boundaries = np.cumsum(np.bincount(inverse, minlength=uniques.size))[:-1]
    return uniques, np.split(order, boundaries)

971 

972 

973################################################### 

974# CYCLE FEATURE FUNCS 

975 

def cf_start_value(x):
    """Return the value held by the first sample of a cycle.

    Parameters
    ----------
    x : array_like
        Samples belonging to a single cycle.

    """
    first_sample = x[0]
    return first_sample

979 

980 

def cf_end_value(x):
    """Return the value held by the final sample of a cycle.

    Parameters
    ----------
    x : array_like
        Samples belonging to a single cycle.

    """
    last_sample = x[-1]
    return last_sample

984 

985 

def cf_peak_sample(x, interp=True):
    """Return the location of the highest peak within a single cycle.

    Parameters
    ----------
    x : ndarray
        Samples belonging to one cycle.
    interp : bool
        Forwarded to ``_find_extrema`` as ``parabolic_extrema``.

    Returns
    -------
    The location (as reported by ``_find_extrema``) of the largest peak, or
    None when no peak is found.

    """
    peak_locs, peak_vals = _find_extrema(x, parabolic_extrema=interp)
    return peak_locs[np.argmax(peak_vals)] if len(peak_vals) else None

993 

994 

def cf_peak_value(x, interp=True):
    """Return the height of the highest peak within a single cycle.

    Parameters
    ----------
    x : ndarray
        Samples belonging to one cycle.
    interp : bool
        Forwarded to ``_find_extrema`` as ``parabolic_extrema``.

    Returns
    -------
    The largest peak value reported by ``_find_extrema``, or None when no
    peak is found.

    """
    _, peak_vals = _find_extrema(x, parabolic_extrema=interp)
    return peak_vals[np.argmax(peak_vals)] if len(peak_vals) else None

1002 

1003 

def cf_trough_sample(x, interp=True):
    """Return the location of the deepest trough within a single cycle.

    Troughs are found by running the extremum finder on the inverted signal.

    Parameters
    ----------
    x : ndarray
        Samples belonging to one cycle.
    interp : bool
        Forwarded to ``_find_extrema`` as ``parabolic_extrema``.

    Returns
    -------
    The location of the lowest trough, or None when no trough is found.

    """
    trough_locs, flipped_vals = _find_extrema(-x, parabolic_extrema=interp)
    # Undo the inversion so the values are back on the original scale.
    trough_vals = -flipped_vals
    return trough_locs[np.argmin(trough_vals)] if len(trough_vals) else None

1012 

1013 

def cf_trough_value(x, interp=True):
    """Return the value of the deepest trough within a single cycle.

    Troughs are found by running the extremum finder on the inverted signal.

    Parameters
    ----------
    x : ndarray
        Samples belonging to one cycle.
    interp : bool
        Forwarded to ``_find_extrema`` as ``parabolic_extrema``.

    Returns
    -------
    The lowest trough value, or None when no trough is found.

    """
    _, flipped_vals = _find_extrema(-x, parabolic_extrema=interp)
    # Undo the inversion so the values are back on the original scale.
    trough_vals = -flipped_vals
    return trough_vals[np.argmin(trough_vals)] if len(trough_vals) else None

1023 

1024 

def cf_descending_zero_sample(x, interp=True):
    """Compute index of descending zero-crossing in a single cycle.

    A crossing is a pair of consecutive samples whose sign goes from strictly
    positive to strictly negative. A sample exactly equal to zero has sign 0,
    so it does not form a crossing with its neighbours.

    Parameters
    ----------
    x : array_like
        Samples belonging to one cycle.
    interp : bool
        If True, return the fractional position where the straight line
        between the two bracketing samples reaches zero. If False, return the
        index of the last positive sample before the crossing.

    Returns
    -------
    Position of the first descending crossing, or None if there is none.

    """
    x = np.asarray(x)
    desc = np.where(np.diff(np.sign(x)) == -2)[0]
    if len(desc) == 0:
        return None
    desc = desc[0]
    if interp:
        # Exact root of the linear interpolant. x[desc] > 0 > x[desc+1], so
        # the denominator is positive and the offset lies in (0, 1).
        desc = desc + x[desc] / (x[desc] - x[desc + 1])
    return desc

1036 

1037 

def cf_ascending_zero_sample(x, interp=True):
    """Compute index of ascending zero-crossing in a single cycle.

    A crossing is a pair of consecutive samples whose sign goes from strictly
    negative to strictly positive. A sample exactly equal to zero has sign 0,
    so it does not form a crossing with its neighbours.

    Parameters
    ----------
    x : array_like
        Samples belonging to one cycle.
    interp : bool
        If True, return the fractional position where the straight line
        between the two bracketing samples reaches zero. If False, return the
        index of the last negative sample before the crossing.

    Returns
    -------
    Position of the first ascending crossing, or None if there is none.

    """
    x = np.asarray(x)
    asc = np.where(np.diff(np.sign(x)) == 2)[0]
    if len(asc) == 0:
        return None
    asc = asc[0]
    if interp:
        # Exact root of the linear interpolant. x[asc] < 0 < x[asc+1], so
        # numerator and denominator are both negative and the offset is in (0, 1).
        asc = asc + x[asc] / (x[asc] - x[asc + 1])
    return asc

1049 

1050 

1051################################################### 

1052# ITERATING OVER CYCLES 

1053 

1054 

def _ensure_cycle_inputs(invar):
    """Take a variable and return a valid iterable cycles class if possible.

    Parameters
    ----------
    invar : ndarray, Cycles or IterateCycles
        A cycle-vector, a ``Cycles`` object or an existing ``IterateCycles``.

    Returns
    -------
    IterateCycles
        A bare array is wrapped as a cycle-vector, a ``Cycles`` object is
        converted with its ``iterate()`` method, and an ``IterateCycles`` is
        returned unchanged.

    Raises
    ------
    ValueError
        If ``invar`` is none of the accepted types.

    """
    if isinstance(invar, np.ndarray):
        # Assume we have a cycles vector; report errors under this function's name.
        invar = ensure_vector([invar], ['cycles'], '_ensure_cycle_inputs')
        return IterateCycles(cycle_vect=invar)
    if isinstance(invar, Cycles):
        return invar.iterate()
    if isinstance(invar, IterateCycles):
        return invar
    raise ValueError("'cycles' input not recognised, must be either a cycle-vector, "
                     "Cycles class or IterateCycles class")

1067 

1068 

class IterateCycles:
    """Iterator class to loop through cycles in a Cycles object.

    Iterating yields ``(index, inds)`` pairs, where ``inds`` comes from the
    matching ``_cycles_support`` mapping function.
    """

    def __init__(self, iter_through='cycles', mode='cycle', valids=None,
                 cycle_vect=None, subset_vect=None, chain_vect=None, phase=None):
        """Iterate through sets of cycles.

        Parameters
        ----------
        iter_through : {'cycles', 'subset', 'chains'}
            What to loop over. Replaced by ``'valids'`` when ``valids`` is given.
        mode : {'cycle', 'augmented'}
            Whether each item is mapped to its plain or augmented samples.
            Not used when iterating through chains.
        valids : ndarray, optional
            Per-cycle selection mask; only cycles where it is nonzero/True are
            visited.
        cycle_vect, subset_vect, chain_vect : ndarray, optional
            Label vectors; item counts are taken as ``max() + 1``.
        phase : ndarray, optional
            Instantaneous phase, required by ``mode='augmented'``.

        """
        self.cycle_vect = cycle_vect
        self.subset_vect = subset_vect
        self.chain_vect = chain_vect
        self.phase = phase
        self.valids = valids

        self.mode = mode
        # A selection mask always takes precedence over iter_through.
        if valids is None:
            self.iter_through = iter_through
        else:
            self.iter_through = 'valids'

        # Counts assume zero-based labels (count = max() + 1).
        if self.cycle_vect is not None:
            self.ncycles = cycle_vect.max() + 1
            self.nsamples = cycle_vect.shape[0]
        if self.subset_vect is not None:
            self.nsubset = subset_vect.max() + 1
        if self.chain_vect is not None:
            self.nchain = chain_vect.max() + 1

    @property
    def niters(self):
        """Return the number of items the iterator is set up to visit.

        Raises
        ------
        ValueError
            If ``iter_through`` is not recognised.

        """
        if self.iter_through == 'cycles':
            return self.cycle_vect.max() + 1
        elif self.iter_through == 'valids':
            # One item per selected cycle. In 'augmented' mode, iterate_valids
            # skips cycles without augmented indices, so fewer may be yielded.
            return self.valids.sum()
        elif self.iter_through == 'subset':
            return self.subset_vect.max() + 1
        elif self.iter_through == 'chains':
            return self.chain_vect.max() + 1
        raise ValueError("Unknown iter_through '{0}'".format(self.iter_through))

    def __iter__(self):
        """Iterate through cycles."""
        if self.iter_through == 'cycles':
            return self.iterate_cycles()
        elif self.iter_through == 'valids':
            return self.iterate_valids()
        elif self.iter_through == 'subset':
            return self.iterate_subset()
        elif self.iter_through == 'chains':
            return self.iterate_chains()
        else:
            raise ValueError

    def iterate_cycles(self):
        """Iterate through all cycles, yielding ``(cycle_index, inds)``."""
        for ii in range(self.ncycles):
            if self.mode == 'cycle':
                inds = _cycles_support.map_cycle_to_samples(self.cycle_vect, ii)
                yield ii, inds
            elif self.mode == 'augmented':
                inds = _cycles_support.map_cycle_to_samples_augmented(self.cycle_vect, ii, self.phase)
                yield ii, inds
            else:
                raise ValueError

    def iterate_valids(self):
        """Iterate through a custom set of matching cycles.

        The yielded index is the position among the selected cycles, not the
        original cycle number.
        """
        for idx, ii in enumerate(np.where(self.valids)[0]):
            if self.mode == 'cycle':
                inds = _cycles_support.map_cycle_to_samples(self.cycle_vect, ii)
                yield idx, inds
            elif self.mode == 'augmented':
                inds = _cycles_support.map_cycle_to_samples_augmented(self.cycle_vect, ii, self.phase)
                if inds is None:
                    continue
                yield idx, inds
            else:
                raise ValueError

    def iterate_subset(self):
        """Iterate through the fixed subset of cycles."""
        for ii in range(self.nsubset):
            if self.mode == 'cycle':
                inds = _cycles_support.map_subset_to_sample(self.subset_vect, self.cycle_vect, ii)
                yield ii, inds
            elif self.mode == 'augmented':
                inds = _cycles_support.map_subset_to_sample_augmented(self.subset_vect, self.cycle_vect,
                                                                      ii, self.phase)
                yield ii, inds
            else:
                raise ValueError

    def iterate_chains(self):
        """Iterate through all chains."""
        for ii in range(self.nchain):
            inds = _cycles_support.map_chain_to_samples(self.chain_vect, self.subset_vect, self.cycle_vect, ii)
            yield ii, inds

1163 

1164################################################### 

1165# THE CYCLES CLASS 

1166 

1167 

1168class Cycles: 

1169 """Find, store and analyse single cycles [1]_. 

1170 

1171 References 

1172 ---------- 

1173 .. [1] Andrew J. Quinn, Vitor Lopes-dos-Santos, Norden Huang, Wei-Kuang 

1174 Liang, Chi-Hung Juan, Jia-Rong Yeh, Anna C. Nobre, David Dupret, & Mark W. 

1175 Woolrich (2021). Within-cycle instantaneous frequency profiles report 

1176 oscillatory waveform dynamics. bioRxiv, 2021.04.12.439547. 

1177 https://doi.org/10.1101/2021.04.12.439547 

1178 

1179 """ 

1180 

def __init__(self, IP, phase_step=1.5 * np.pi, phase_edge=np.pi/12, min_len=2,
             compute_timings=False, mode='cycle', use_cache=True):
    """Class storing and manipulating single cycles.

    Parameters
    ----------
    IP : ndarray
        Instantaneous phase time-course, validated with ``ensure_vector``.
    phase_step : float
        Phase-jump threshold passed to ``get_cycle_vector`` to mark cycle
        boundaries.
    phase_edge : float
        Phase tolerance passed to both ``get_cycle_vector`` and ``is_good``.
    min_len : int
        Minimum-length criterion forwarded to ``is_good``.
    compute_timings : bool
        If True, also compute start/stop/duration metrics for every cycle.
    mode : str
        Not used by the constructor.
    use_cache : bool
        If True, precompute the slice caches used by ``compute_cycle_metric``.

    """
    logger.info('Initialising Cycles')
    self.phase_step = phase_step
    self.phase_edge = phase_edge

    self.phase = ensure_vector([IP], ['IP'], 'Cycles')
    # Cycles are not filtered here (return_good=False); their quality is
    # recorded in the 'is_good' metric below instead.
    self.cycle_vect = get_cycle_vector(self.phase, return_good=False,
                                       phase_step=phase_step, phase_edge=phase_edge)
    self.ncycles = self.cycle_vect.max() + 1
    self.nsamples = self.phase.shape[0]
    logger.debug('{0} cycles identified (avg len {1} samples)'.format(self.ncycles, self.nsamples/self.ncycles))

    if use_cache:
        logger.debug('Populating slice cache')
        self._slice_cache = _cycles_support.make_slice_cache(self.cycle_vect)
        self._slice_cache_aug = _cycles_support.make_aug_slice_cache(self._slice_cache, self.phase)
    else:
        self._slice_cache = None
        self._slice_cache_aug = None

    # Populated later by pick_cycle_subset.
    self.subset_vect = None
    self.chain_vect = None
    self.mask_conditions = None

    self.metrics = dict()
    good_func = partial(is_good, phase_edge=phase_edge, min_len=min_len)
    self.compute_cycle_metric('is_good', self.phase, good_func, dtype=int)
    if compute_timings:
        self.compute_cycle_timings()

1213 

def __repr__(self):
    """Print a short summary of cycle, subset, chain and metric counts."""
    nmetrics = len(self.metrics.keys())
    if self.subset_vect is None:
        return "{0} ({1} cycles {2} metrics) ".format(type(self), self.ncycles, nmetrics)

    msg = "{0} ({1} cycles {2} subset {3} chains - {4} metrics) "
    # Subset and chain vectors hold zero-based labels, so each count is max + 1.
    return msg.format(type(self),
                      self.ncycles,
                      self.subset_vect.max() + 1,
                      self.chain_vect.max() + 1,
                      nmetrics)

1227 

1228 # ---------------------- 

1229 

def __iter__(self):
    """Loop over every cycle using the default ``iterate()`` settings."""
    looper = self.iterate()
    return iter(looper)

1233 

def iterate(self, through='cycles', conditions=None, mode='cycle'):
    """Build an ``IterateCycles`` object over some or all cycles.

    Parameters
    ----------
    through : str
        Passed to ``IterateCycles`` as ``iter_through``.
    conditions : str or list of str, optional
        If given, only cycles satisfying every condition (as evaluated by
        ``get_matching_cycles``) are visited.
    mode : str
        ``'cycle'`` or ``'augmented'`` sample windows.

    """
    valids = None if conditions is None else self.get_matching_cycles(conditions)
    return IterateCycles(iter_through=through,
                         mode=mode,
                         valids=valids,
                         cycle_vect=self.cycle_vect,
                         subset_vect=self.subset_vect,
                         chain_vect=self.chain_vect,
                         phase=self.phase)

1245 

1246 # ---------------------- 

1247 

def get_inds_of_cycle(self, ii, mode='cycle'):
    """Find indices of specified cycle.

    Parameters
    ----------
    ii : int
        Cycle number.
    mode : {'cycle', 'augmented'}
        Map with ``map_cycle_to_samples`` or ``map_cycle_to_samples_augmented``.

    Raises
    ------
    ValueError
        If ``mode`` is not recognised (previously None was silently returned).

    """
    if mode == 'cycle':
        return _cycles_support.map_cycle_to_samples(self.cycle_vect, ii)
    if mode == 'augmented':
        return _cycles_support.map_cycle_to_samples_augmented(self.cycle_vect, ii, self.phase)
    raise ValueError("mode must be 'cycle' or 'augmented', got {0!r}".format(mode))

1256 

def get_cycle_vector(self, ii, mode='cycle'):
    """Return the sample mapping of cycle ``ii``.

    Parameters
    ----------
    ii : int
        Cycle number.
    mode : {'cycle', 'augmented'}
        ``'cycle'`` uses ``map_cycle_to_samples``; ``'augmented'`` uses
        ``map_cycle_to_samples_augmented`` together with the stored phase.

    Raises
    ------
    ValueError
        If ``mode`` is neither of the above.

    """
    if mode == 'augmented':
        return _cycles_support.map_cycle_to_samples_augmented(self.cycle_vect, ii, self.phase)
    if mode != 'cycle':
        raise ValueError
    return _cycles_support.map_cycle_to_samples(self.cycle_vect, ii)

1265 

def get_metric_dataframe(self, subset=False, conditions=None):
    """Return the stored cycle metrics as a pandas DataFrame.

    Parameters
    ----------
    subset : bool
        If True, keep only cycles matching ``self.mask_conditions``.
    conditions : str or list of str, optional
        Keep only cycles matching these conditions. Cannot be combined with
        ``subset=True``.

    Returns
    -------
    pandas.DataFrame
        One row per cycle. When rows are filtered, the index is reset and the
        original row labels are kept in an ``'index'`` column.

    """
    import pandas as pd
    frame = pd.DataFrame.from_dict(self.metrics)

    if subset:
        if conditions is not None:
            raise ValueError("Please specify either 'subset=True' or a set of conditions")
        conditions = self.mask_conditions

    if conditions is None:
        return frame

    rejected = np.logical_not(self.get_matching_cycles(conditions))
    frame = frame.drop(np.where(rejected)[0])
    return frame.reset_index()

1282 

def get_matching_cycles(self, conditions, ret_separate=False):
    """Evaluate condition strings against the stored per-cycle metrics.

    Parameters
    ----------
    conditions : str or list of str
        Conditions such as ``'is_good==1'``, parsed by ``_parse_condition``.
    ret_separate : bool
        If True, return the per-condition results instead of combining them.

    Returns
    -------
    ndarray
        Boolean vector of cycles meeting every condition, or, if
        ``ret_separate`` is True, a float array of shape
        ``(ncycles, nconditions)`` holding each condition's outcome.

    """
    if isinstance(conditions, str):
        conditions = [conditions]

    n_cycles = len(self.metrics['is_good'])
    results = np.zeros((n_cycles, len(conditions)))
    for col, condition in enumerate(conditions):
        metric_name, comparator, threshold = self._parse_condition(condition)
        results[:, col] = comparator(self.metrics[metric_name], threshold)

    return results if ret_separate else results.all(axis=1)

1297 

def add_cycle_metric(self, name, cycle_vals, dtype=None):
    """Add an externally computed per-cycle metric.

    Parameters
    ----------
    name : str
        Key under which the metric is stored in ``self.metrics``.
    cycle_vals : array_like
        One value per cycle.
    dtype : type, optional
        If given, a converted copy is stored and the caller's array is left
        untouched. For ``int``, NaNs are stored as -1, since integer arrays
        cannot hold NaN.

    Raises
    ------
    ValueError
        If the number of values does not match ``self.ncycles``.

    """
    if len(cycle_vals) != self.ncycles:
        msg = "Input metrics ({0}) mismatched to existing metrics ({1})"
        # This error was previously returned instead of raised, so a bad
        # metric was silently dropped.
        raise ValueError(msg.format(np.shape(cycle_vals), self.ncycles))

    if dtype is not None:
        # Copy so the NaN replacement below does not mutate the caller's data.
        cycle_vals = np.array(cycle_vals)
        if dtype is int:
            cycle_vals[np.isnan(cycle_vals)] = -1
        cycle_vals = cycle_vals.astype(dtype)

    self._safe_add_metric(name, cycle_vals)

1310 

def _safe_add_metric(self, name, vals):
    """Store ``vals`` under ``name`` after checking there is one value per cycle.

    Raises
    ------
    ValueError
        If ``len(vals)`` differs from ``self.ncycles``.

    """
    if len(vals) == self.ncycles:
        self.metrics[name] = vals
        return
    raise ValueError

1315 

1316 # ---------------------- 

1317 

def compute_position_in_chain(self):
    """Store each cycle's zero-based position within its chain.

    The result is stored as the ``'chain_position'`` metric. Cycles outside
    the subset get -1.

    Raises
    ------
    ValueError
        If no chains exist yet (``chain_vect`` is None until
        ``pick_cycle_subset`` has been called).

    """
    if self.chain_vect is None:
        raise ValueError

    positions = np.zeros_like(self.chain_vect)
    for chain_id in range(self.chain_vect.max() + 1):
        members = np.where(self.chain_vect == chain_id)[0]
        positions[members] = np.arange(len(members))

    per_cycle = _cycles_support.project_subset_to_cycles(positions, self.subset_vect)
    per_cycle[np.isnan(per_cycle)] = -1
    self.metrics['chain_position'] = per_cycle.astype(int)

1332 

def compute_cycle_metric(self, name, vals, func, dtype=None, mode='cycle'):
    """Compute a statistic for all cycles.

    Results are stored in the Cycle object for later use.

    Parameters
    ----------
    name : str
        Key under which the result is stored in ``self.metrics``.
    vals : ndarray
        Per-sample values to summarise, e.g. ``self.phase``.
    func : callable
        Reduction applied to the samples of each cycle.
    dtype : type, optional
        If given, per-cycle results are cast to this type. When floating
        results are cast to an integer type, NaNs become -1 rather than an
        arbitrary integer.
    mode : {'cycle', 'augmented'}
        Use plain or augmented cycle sample windows.

    Raises
    ------
    ValueError
        If ``mode`` is not recognised.

    """
    logger.info("Computing metric '{0}' using {1} with mode '{2}'".format(name, func, mode))
    if mode == 'cycle':
        if self._slice_cache is None:
            vals = _cycles_support.get_cycle_stat_from_samples(vals, self.cycle_vect, func=func)
        else:
            vals = _cycles_support.get_slice_stat_from_samples(vals, self._slice_cache, func=func)
    elif mode == 'augmented':
        if self._slice_cache_aug is None:
            vals = _cycles_support.get_augmented_cycle_stat_from_samples(vals, self.cycle_vect,
                                                                         self.phase, func=func)
        else:
            vals = _cycles_support.get_slice_stat_from_samples(vals, self._slice_cache_aug, func=func)
    else:
        raise ValueError

    if dtype is not None:
        vals = np.asarray(vals)
        # Casting NaN to an integer type yields a meaningless huge value, so
        # missing results are marked -1 instead (the convention also used by
        # add_cycle_metric and compute_chain_metric).
        if np.issubdtype(np.dtype(dtype), np.integer) and np.issubdtype(vals.dtype, np.floating):
            vals = np.where(np.isnan(vals), -1, vals)
        vals = vals.astype(dtype)
    self.add_cycle_metric(name, vals)

1357 

def compute_chain_metric(self, name, vals, func, dtype=None):
    """Compute a per-chain statistic and store it against every cycle.

    Each chain's result is projected back onto its member cycles and saved
    with ``add_cycle_metric``.

    Parameters
    ----------
    name : str
        Metric key.
    vals : ndarray
        Per-sample values to summarise.
    func : callable
        Reduction applied to the samples of each chain.
    dtype : type, optional
        If given, NaNs are replaced by -1 and the result is cast to this type.

    Raises
    ------
    ValueError
        If no subset has been picked yet (``mask_conditions`` is None).

    """
    if self.mask_conditions is None:
        raise ValueError

    per_chain = _cycles_support.get_chain_stat_from_samples(vals, self.chain_vect,
                                                            self.subset_vect, self.cycle_vect,
                                                            func=func)
    per_cycle = _cycles_support.project_chain_to_cycles(per_chain, self.chain_vect, self.subset_vect)

    if dtype is not None:
        # Integer arrays cannot hold NaN, so flag missing values as -1 first.
        per_cycle[np.isnan(per_cycle)] = -1
        per_cycle = per_cycle.astype(dtype)

    self.add_cycle_metric(name, per_cycle)

1373 

def compute_cycle_timings(self):
    """Store the start sample, stop sample and duration of every cycle.

    The metrics are saved as ``'start_sample'``, ``'stop_sample'`` and
    ``'duration'``, all cast to int.
    """
    nsamples = len(self.cycle_vect)
    timing_specs = [
        ('start_sample', np.arange(nsamples), cf_start_value),
        ('stop_sample', np.arange(nsamples), cf_end_value),
        ('duration', self.cycle_vect, len),
    ]
    for metric_name, source, reducer in timing_specs:
        self.compute_cycle_metric(metric_name, source, reducer, dtype=int)

1388 

def compute_chain_timings(self):
    """Store start, end and length metrics for every chain, plus chain positions.

    Saves ``'chain_start'``, ``'chain_end'``, ``'chain_len_samples'`` and
    ``'chain_len_cycles'``, then calls ``compute_position_in_chain``.
    """
    def _count_cycles(x):
        # Number of distinct cycle labels spanned by the chain.
        return len(np.unique(x))

    timing_specs = [
        ('chain_start', np.arange(0, len(self.cycle_vect)), cf_start_value),
        ('chain_end', np.arange(0, len(self.cycle_vect)), cf_end_value),
        ('chain_len_samples', self.cycle_vect, len),
        ('chain_len_cycles', self.cycle_vect, _count_cycles),
    ]
    for metric_name, source, reducer in timing_specs:
        self.compute_chain_metric(metric_name, source, reducer, dtype=int)

    self.compute_position_in_chain()

1399 

def pick_cycle_subset(self, conditions):
    """Set conditions to define subsets + chains. This is not reversible for the moment.

    Parameters
    ----------
    conditions : str or list of str
        Conditions passed to ``get_matching_cycles``. They are also saved as
        ``self.mask_conditions``.

    Also stores each cycle's chain number as the ``'chain_ind'`` metric.
    """
    self.mask_conditions = conditions

    matching = self.get_matching_cycles(conditions)
    self.subset_vect = get_subset_vector(matching)
    self.chain_vect = get_chain_vector(self.subset_vect)

    chain_ids = np.arange(self.chain_vect.max() + 1)
    chain_per_cycle = _cycles_support.project_chain_to_cycles(chain_ids, self.chain_vect, self.subset_vect)
    self.add_cycle_metric('chain_ind', chain_per_cycle, dtype=int)

1411 

1412 # ---------------------- 

1413 

def _parse_condition(self, cond):
    """Parse strings defining conditional statements.

    Parameters
    ----------
    cond : str
        A metric name followed by ``==``, ``!=``, ``<=``, ``>=``, ``<`` or
        ``>`` and a number, e.g. ``'is_good==1'``. Whitespace around the name
        and value is ignored.

    Returns
    -------
    name : str
        Metric name.
    func : numpy ufunc
        The comparison function.
    val : float
        Value to compare against.

    Raises
    ------
    ValueError
        If no recognised comparator follows the name, or the value is not a
        number.

    """
    name = re.split(r'[=<>!]', cond)[0]
    comp = cond[len(name):]
    name = name.strip()

    # Two-character comparators are checked before their one-character prefixes.
    comparators = (('==', np.equal),
                   ('!=', np.not_equal),
                   ('<=', np.less_equal),
                   ('>=', np.greater_equal),
                   ('<', np.less),
                   ('>', np.greater))
    for symbol, func in comparators:
        if comp.startswith(symbol):
            break
    else:
        # Previously this only printed a message and then crashed with an
        # UnboundLocalError on 'func'.
        raise ValueError("Comparator not recognised in condition '{0}'".format(cond))

    val = float(comp.lstrip('!=<>'))

    return (name, func, val)