Coverage for emd/plotting.py: 5%

220 statements  

« prev     ^ index     » next       coverage.py v7.6.11, created at 2025-03-08 15:44 +0000

1#!/usr/bin/python 

2 

3# vim: set expandtab ts=4 sw=4: 

4 

5""" 

6Routines for plotting results of EMD analyses. 

7 

8Main Routines: 

9 plot_imfs 

10 plot_hilberthuang 

11 plot_holospectrum 

12 

13Utilities: 

14 _get_log_tickpos 

15 

16""" 

17 

18import logging 

19from functools import partial 

20 

21import numpy as np 

22 

23# Housekeeping for logging 

24logger = logging.getLogger(__name__) 

25 

26 

def plot_imfs(imfs, time_vect=None, X=None, step=4, sample_rate=1, sharey=True,
              scale_y=None, cmap=True, fig=None, fig_args=None, ax=None,
              xlabel='Time (samples)', ylabel_args=None, ylabel_xoffset=-0.08,
              tick_params=None):
    """Create a quick summary plot for a set of IMFs.

    All rows are drawn into a single axis. The top row shows the raw signal
    (or the sum of the IMFs) and each subsequent row shows one IMF, offset
    downwards by ``step``.

    Parameters
    ----------
    imfs : ndarray
        2D array of IMFs to plot, indexed as [samples, imfs]
    time_vect : ndarray
        Optional 1D array specifying time values (Default value = None)
    X : ndarray
        Original data prior to decomposition. If passed, this will be plotted
        on the top row rather than the sum of the IMFs. Useful for visualising
        incomplete sets of IMFs.
    step : float
        Scaling factor determining spacing between IMF subaxes, approximately
        corresponds to the z-value of the y-axis extremeties for each IMF. If
        there is substantial overlap between IMFs, this value can be increased
        to compensate.
    sample_rate : float
        Optional sample rate used to build the time axis values when
        time_vect is not specified.
    sharey: Boolean
        Flag indicating whether the y-axis should be adaptive to each mode
        (False) or consistent across modes (True) (Default value = True)
    scale_y : {None, bool}
        Deprecated. ``scale_y=True`` is equivalent to ``sharey=False``; any
        other non-None value is equivalent to ``sharey=True``. Passing it logs
        a warning.
    cmap : {None,True,matplotlib colormap}
        Optional colourmap to use. None will plot each IMF in black and True will
        use the plt.cm.Dark2 colormap as default. A different colormap may also
        be passed in.
    fig : matplotlib figure instance
        Optional figure to make the plot in.
    fig_args : dict
        Dictionary of kwargs to pass to plt.figure, only used if 'fig' is not passed.
    ax : matplotlib axes instance
        Optional axes to make the plot in.
    xlabel : str
        Optional x-axis label. Defaults to 'Time (samples)'
    ylabel_args : dict
        Optional arguments to be passed to plt.text to create the y-axis
        labels. Defaults to {'ha': 'center', 'va': 'center', 'fontsize': 14}.
        These values remain unless explicitly overridden. The passed dict is
        not modified.
    ylabel_xoffset : float
        Optional axis offset to fine-tune the position of the y axis labels.
        Defaults to -0.08 and typically only needs VERY minor adjustment.
    tick_params : dict
        Optional arguments passed to ax.tick_params to style the tick labels.
        Defaults to {'axis': 'both', 'which': 'major', 'labelsize': 10}.
        These values remain unless explicitly overridden. The passed dict is
        not modified.

    Returns
    -------
    ax
        Handle of plot axis

    """
    import matplotlib.pyplot as plt
    import matplotlib.transforms as transforms
    from matplotlib.colors import Colormap
    from scipy import stats

    if scale_y is not None:
        logger.warning("The argument 'scale_y' is deprecated and will be "
                       "removed in a future version. Please use 'sharey' to "
                       "remove this warning")
        sharey = False if scale_y is True else True

    if time_vect is None:
        time_vect = np.linspace(0, imfs.shape[0]/sample_rate, imfs.shape[0])

    # Copy before filling in defaults so the caller's dicts are left untouched
    ylabel_args = {} if ylabel_args is None else dict(ylabel_args)
    ylabel_args.setdefault('ha', 'center')
    ylabel_args.setdefault('va', 'center')
    ylabel_args.setdefault('fontsize', 14)

    tick_params = {} if tick_params is None else dict(tick_params)
    tick_params.setdefault('axis', 'both')
    tick_params.setdefault('which', 'major')
    tick_params.setdefault('labelsize', 10)

    top_label = 'Summed\nIMFs' if X is None else 'Raw\nSignal'
    X = imfs.sum(axis=1) if X is None else X

    # Decimal places used to round tick labels. np.log is the natural log, so
    # for std < 1 this errs towards more decimal places than log10 would.
    order_of_magnitude = int(np.floor(np.log(X.std())))
    round_scale = -order_of_magnitude if order_of_magnitude < 0 else 12

    # One flag drives both the data scaling and the tick layout so the two
    # always agree.
    per_imf = sharey is False

    # Everything is z-transformed internally to make positioning in the axis
    # simpler. We either z-transform relative to summed signal or for each imf
    # in turn. Also divide per-imf scaled data by 2 to reduce overlap as
    # z-transform relative to full signal will naturally give smaller ranges.
    #
    # Either way - Scale based on variance of raw data - don't touch the mean.
    if per_imf:
        def scale_func(x):
            return (stats.zscore(x) + x.mean()) / 2
    else:
        scale_func = partial(stats.zmap, compare=X-X.mean())

    def tick_layout(series):
        """Return (positions, labels) of y-ticks for one row, centred on 0."""
        std = series.std() if per_imf else X.std()
        if not std > 0:
            # A flat series has no meaningful tick spacing
            return np.array([]), np.array([])
        labels = _get_sensible_ticks(1.5 * std)
        # Per-series rows are halved by scale_func, so tick positions must be too
        positions = labels / (2 * std) if per_imf else labels / std
        return positions, np.round(labels, round_scale)

    if fig is None and ax is None:
        if fig_args is None:
            fig_args = {'figsize': (16, 10)}
        fig = plt.figure(**fig_args)
        fig.subplots_adjust(top=0.975, right=0.975)

    if ax is None:
        ax = fig.add_subplot(111)

    # Style the ticks of the axes actually being drawn into
    ax.tick_params(**tick_params)

    if cmap is True:
        # Use default colormap
        cmap = plt.cm.Dark2
    if isinstance(cmap, Colormap):
        cols = cmap(np.linspace(0, 1, imfs.shape[1] + 1))
    else:
        # Use all black lines - this is overall default
        cols = np.array([[0, 0, 0] for ii in range(imfs.shape[1] + 1)])

    # Blended transform uses axis coords for X and data coords for Y
    trans = transforms.blended_transform_factory(ax.transAxes, ax.transData)

    yticks = []
    yticklabels = []

    # Plot full time-series
    first_step = 0
    ax.plot((time_vect[0], time_vect[0]), ((first_step)-1.5, (first_step)+1.5), 'k')
    ax.plot(time_vect, np.zeros_like(time_vect) + first_step,
            lw=0.5, color=[0.8, 0.8, 0.8])
    ax.plot(time_vect, scale_func(X)+first_step, 'k')
    ax.text(ylabel_xoffset, first_step, top_label,
            transform=trans, **ylabel_args)

    positions, labels = tick_layout(X)
    yticks.extend(first_step + positions)
    yticklabels.extend(labels)

    # Main IMF loop
    for ii in range(imfs.shape[1]):
        this_step = (ii+1)*step

        # Plot IMF and axis lines
        ax.plot(time_vect, np.zeros_like(time_vect) - this_step,
                lw=0.5, color=[0.8, 0.8, 0.8])
        ax.plot((time_vect[0], time_vect[0]), (-this_step-1.5, -this_step+1.5), 'k')
        ax.plot(time_vect, scale_func(imfs[:, ii]) - this_step, color=cols[ii, :])

        positions, labels = tick_layout(imfs[:, ii])
        yticks.extend(-this_step + positions)
        yticklabels.extend(labels)

        # Add label
        ax.text(ylabel_xoffset, -this_step, 'IMF-{}'.format(ii+1),
                transform=trans, **ylabel_args)

    # Hide unwanted spines
    for tag in ['left', 'top', 'right']:
        ax.spines[tag].set_visible(False)
    ymax = np.max(scale_func(X)+step/2)

    # Set axis limits
    ax.set_ylim(np.min(yticks)-1, ymax)
    ax.set_xlim(time_vect[0], time_vect[-1])

    # Set axis ticks
    ax.set_yticks(yticks)
    ax.set_yticklabels(yticklabels)

    ax.set_xlabel(xlabel, fontsize=ylabel_args.get('fontsize', 14))

    return ax

211 

212 

213def _get_sensible_ticks(lim, nbins=3): 

214 """Return sensibly rounded tick positions based on a plotting range. 

215 

216 Based on code in matplotlib.ticker 

217 Assuming symmetrical axes and 3 ticks for the moment 

218 

219 """ 

220 from matplotlib import ticker 

221 scale, offset = ticker.scale_range(-lim, lim) 

222 if lim/scale > 0.5: 

223 scale = scale / 2 

224 edge = ticker._Edge_integer(scale, offset) 

225 low = edge.ge(-lim) 

226 high = edge.le(lim) 

227 

228 ticks = np.linspace(low, high, nbins) * scale 

229 

230 return ticks 

231 

232 

def plot_imfs_depreciated(imfs, time_vect=None, sample_rate=1, scale_y=False, freqs=None, cmap=None, fig=None):
    """Create a quick summary plot for a set of IMFs (legacy layout).

    Draws one subplot per row: the summed IMFs on top, followed by each IMF
    in its own subplot. Nothing is returned; the figure is modified in place.

    Parameters
    ----------
    imfs : ndarray
        2D array of IMFs to plot, indexed as [samples, imfs]
    time_vect : ndarray
        Optional 1D array specifying time values (Default value = None)
    sample_rate : float
        Optional sample rate used to build the time axis values when
        time_vect is not specified.
    scale_y : Boolean
        Flag indicating whether the y-axis should be adaptive to each mode
        (False) or consistent across modes (True) (Default value = False)
    freqs : array_like
        Optional vector of frequencies for each IMF, shown as subplot titles
    cmap : {None,True,matplotlib colormap}
        Optional colourmap to use. None will plot each IMF in black and True will
        use the plt.cm.Dark2 colormap as default. A different colormap may also
        be passed in.
    fig : matplotlib figure instance
        Optional figure to draw into; a new figure is created if omitted.

    """
    import matplotlib.pyplot as plt
    from matplotlib.colors import Colormap

    n_rows = imfs.shape[1] + 1
    if time_vect is None:
        time_vect = np.linspace(0, imfs.shape[0] / sample_rate, imfs.shape[0])
    t_start, t_end = time_vect[0], time_vect[-1]

    summed = imfs.sum(axis=1)
    imf_peak = np.abs(imfs).max()
    sig_peak = np.abs(summed).max()

    if fig is None:
        fig = plt.figure()

    def strip_spines(panel):
        # Keep only the left spine on every subplot
        for side in ('top', 'right', 'bottom'):
            panel.spines[side].set_visible(False)

    # Top row: summed signal
    top = fig.add_subplot(n_rows, 1, 1)
    if scale_y:
        top.yaxis.get_major_locator().set_params(integer=True)
    strip_spines(top)
    top.plot((t_start, t_end), (0, 0), color=[.5, .5, .5])
    top.plot(time_vect, summed, 'k')
    top.tick_params(axis='x', labelbottom=False)
    top.set_xlim(t_start, t_end)
    top.set_ylim(-sig_peak * 1.1, sig_peak * 1.1)
    top.set_ylabel('Signal', rotation=0, labelpad=10)

    # Line colours: a colormap (Dark2 when cmap is True) or all black
    if cmap is True:
        cmap = plt.cm.Dark2
    if isinstance(cmap, Colormap):
        cols = cmap(np.linspace(0, 1, n_rows))
    else:
        cols = np.array([[0, 0, 0]] * n_rows)

    # One subplot per IMF
    for row in range(1, n_rows):
        panel = fig.add_subplot(n_rows, 1, row + 1)
        strip_spines(panel)
        panel.plot((t_start, t_end), (0, 0), color=[.5, .5, .5])
        panel.plot(time_vect, imfs[:, row - 1], color=cols[row, :])
        panel.set_xlim(t_start, t_end)
        if scale_y:
            panel.set_ylim(-imf_peak * 1.1, imf_peak * 1.1)
            panel.yaxis.get_major_locator().set_params(integer=True)
        panel.set_ylabel('IMF {0}'.format(row), rotation=0, labelpad=10)

        if row == n_rows - 1:
            panel.set_xlabel('Time')
        else:
            panel.tick_params(axis='x', labelbottom=False)
        if freqs is not None:
            panel.set_title(freqs[row - 1], fontsize=8)

    fig.subplots_adjust(top=.95, bottom=.1, left=.2, right=.99)

311 

312 

def plot_hilberthuang(hht, time_vect, freq_vect,
                      time_lims=None, freq_lims=None, log_y=False,
                      vmin=0, vmax=None,
                      fig=None, ax=None, cmap='hot_r'):
    """Create a quick summary plot for a Hilbert-Huang Transform.

    Parameters
    ----------
    hht : 2d array
        Hilbert-Huang spectrum to be plotted - output from emd.spectra.hilberthuang,
        indexed as [frequency, time]
    time_vect : vector
        Vector of time samples
    freq_vect : vector
        Vector of frequency bins
    time_lims : optional tuple or list (start_val, end_val)
        Optional time-limits (inclusive) to zoom in time on the x-axis
    freq_lims : optional tuple or list (start_val, end_val)
        Optional frequency-limits (inclusive) to zoom in frequency on the y-axis
    log_y : bool
        Flag indicating whether to set log-scale on y-axis
    vmin : float
        Lower limit of the colour scale (Default value = 0)
    vmax : float
        Upper limit of the colour scale. Defaults to the maximum of the
        plotted region.
    fig : optional figure handle
        Figure to plot inside
    ax : optional axis handle
        Axis to plot inside
    cmap : optional str or matplotlib.cm
        Colormap specification

    Returns
    -------
    ax
        Handle of plot axis

    """
    import matplotlib.pyplot as plt
    from matplotlib import ticker
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    # Make figure if no fig or axis are passed
    if (fig is None) and (ax is None):
        fig = plt.figure()

    # Create axis if no axis is passed.
    if ax is None:
        ax = fig.add_subplot(1, 1, 1)

    time_vect = np.asarray(time_vect)
    freq_vect = np.asarray(freq_vect)

    # Get time indices
    if time_lims is not None:
        tinds = np.logical_and(time_vect >= time_lims[0], time_vect <= time_lims[1])
    else:
        tinds = np.ones_like(time_vect).astype(bool)

    # Get frequency indices
    if freq_lims is not None:
        finds = np.logical_and(freq_vect >= freq_lims[0], freq_vect <= freq_lims[1])
    else:
        finds = np.ones_like(freq_vect).astype(bool)
        freq_lims = (freq_vect[0], freq_vect[-1])

    # Make space for colourbar
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.05)

    plot_hht = hht[np.ix_(finds, tinds)]
    if vmax is None:
        vmax = np.max(plot_hht)

    # Make main plot
    pcm = ax.pcolormesh(time_vect[tinds], freq_vect[finds], plot_hht,
                        vmin=vmin, vmax=vmax, cmap=cmap, shading='nearest')

    # Set labels
    ax.set_xlabel('Time')
    ax.set_ylabel('Frequency')
    ax.set_title('Hilbert-Huang Transform')

    # Scale axes if requested
    if log_y:
        ax.set_yscale('log')
        lo, hi = freq_lims[0], freq_lims[1]
        if lo <= 0:
            # log10 is undefined for non-positive values; start the ticks at
            # the smallest positive plotted frequency instead
            plotted = freq_vect[finds]
            positive = plotted[plotted > 0]
            lo = positive.min() if positive.size else hi
        ax.set_yticks(_get_log_tickpos(lo, hi))
        ax.get_yaxis().set_major_formatter(ticker.ScalarFormatter())

    # Add colourbar
    plt.colorbar(pcm, cax=cax, orientation='vertical')

    return ax

395 

396 

def plot_holospectrum(holo, freq_vect, am_freq_vect,
                      freq_lims=None, am_freq_lims=None,
                      log_x=False, log_y=False,
                      vmin=0, vmax=None,
                      fig=None, ax=None, cmap='hot_r', mask=True):
    """Create a quick summary plot for a Holospectrum.

    Parameters
    ----------
    holo : 2d array
        Holospectrum to be plotted - output from emd.spectra.holospectrum,
        indexed as [amplitude modulation frequency, carrier frequency]
    freq_vect : vector
        Vector of frequency values for first-layer
    am_freq_vect : vector
        Vector of frequency values for amplitude modulations in second-layer
    freq_lims : optional tuple or list (start_val, end_val)
        Optional frequency-limits (inclusive) to zoom in frequency on the y-axis
    am_freq_lims : optional tuple or list (start_val, end_val)
        Optional frequency-limits (inclusive) to zoom in amplitude modulation
        frequency on the x-axis
    log_x : bool
        Flag indicating whether to set log-scale on x-axis
    log_y : bool
        Flag indicating whether to set log-scale on y-axis
    vmin : float
        Lower limit of the colour scale (Default value = 0)
    vmax : float
        Upper limit of the colour scale. Defaults to the largest non-NaN value
        in the plotted region.
    fig : optional figure handle
        Figure to plot inside
    ax : optional axis handle
        Axis to plot inside
    cmap : optional str, None or matplotlib colormap
        Colormap specification. None uses matplotlib's default colormap. A
        copy is taken so the masked-cell colour does not affect other plots.
    mask : bool
        Flag indicating whether to blank (shown in grey) cells where the
        amplitude modulation frequency exceeds the carrier frequency
        (Default value = True)

    Returns
    -------
    ax
        Handle of plot axis

    """
    import copy

    import matplotlib.pyplot as plt
    from matplotlib import ticker
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    # Make figure if no fig or axis are passed
    if (fig is None) and (ax is None):
        fig = plt.figure()

    # Create axis if no axis is passed.
    if ax is None:
        ax = fig.add_subplot(1, 1, 1)

    freq_vect = np.asarray(freq_vect)
    am_freq_vect = np.asarray(am_freq_vect)

    # Carrier frequency indices (inclusive, matching plot_hilberthuang)
    if freq_lims is not None:
        finds = np.logical_and(freq_vect >= freq_lims[0], freq_vect <= freq_lims[1])
    else:
        finds = np.ones_like(freq_vect).astype(bool)
        freq_lims = (freq_vect[0], freq_vect[-1])

    # Amplitude modulation frequency indices
    if am_freq_lims is not None:
        am_finds = np.logical_and(am_freq_vect >= am_freq_lims[0],
                                  am_freq_vect <= am_freq_lims[1])
    else:
        am_finds = np.ones_like(am_freq_vect).astype(bool)
        am_freq_lims = (am_freq_vect[0], am_freq_vect[-1])

    # Float copy so masked cells can hold NaN even for integer input
    plot_holo = np.array(holo, dtype=float)
    if mask:
        # Blank cells whose AM frequency exceeds the carrier frequency
        plot_holo[am_freq_vect[:, None] > freq_vect[None, :]] = np.nan

    # Resolve the colourmap, then copy it so set_bad does not modify the
    # globally registered colormap or the caller's Colormap object
    if cmap is None or isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)
    cmap = copy.copy(cmap)

    # Set mask values in colourmap
    cmap.set_bad([0.8, 0.8, 0.8])

    # Make space for colourbar
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.05)

    plot_region = plot_holo[np.ix_(am_finds, finds)]
    if vmax is None:
        # nanmax: masked cells are NaN and would otherwise make vmax NaN
        vmax = np.nanmax(plot_region)

    # Make main plot
    pcm = ax.pcolormesh(am_freq_vect[am_finds], freq_vect[finds], plot_region.T,
                        cmap=cmap, vmin=vmin, vmax=vmax, shading='nearest')

    # Set labels
    ax.set_xlabel('Amplitude Modulation Frequency')
    ax.set_ylabel('Carrier Wave Frequency')
    ax.set_title('Holospectrum')

    def log_tick_lims(lims, plotted):
        # log10 is undefined for non-positive values; if the lower limit is
        # <= 0, start the ticks at the smallest positive plotted value
        lo, hi = lims[0], lims[1]
        if lo <= 0:
            positive = plotted[plotted > 0]
            lo = positive.min() if positive.size else hi
        return lo, hi

    # Scale axes if requested
    if log_y:
        ax.set_yscale('log')
        ax.set_yticks(_get_log_tickpos(*log_tick_lims(freq_lims, freq_vect[finds])))
        ax.get_yaxis().set_major_formatter(ticker.ScalarFormatter())

    if log_x:
        ax.set_xscale('log')
        ax.set_xticks(_get_log_tickpos(*log_tick_lims(am_freq_lims, am_freq_vect[am_finds])))
        ax.get_xaxis().set_major_formatter(ticker.ScalarFormatter())

    # Add colourbar
    plt.colorbar(pcm, cax=cax, orientation='vertical')

    return ax

504 

505 

506def _get_log_tickpos(lo, hi, tick_rate=5, round_vals=True): 

507 """Generate tick positions for log-scales. 

508 

509 Parameters 

510 ---------- 

511 lo : float 

512 Low end of frequency range 

513 hi : float 

514 High end of frequency range 

515 tick_rate : int 

516 Number of ticks per order-of-magnitude 

517 round_vals : bool 

518 Flag indicating whether ticks should be rounded to first non-zero value. 

519 

520 Returns 

521 ------- 

522 ndarray 

523 Vector of tick positions 

524 

525 """ 

526 lo_oom = np.floor(np.log10(lo)).astype(int) 

527 hi_oom = np.ceil(np.log10(hi)).astype(int) + 1 

528 ticks = [] 

529 log_tick_pos_inds = np.round(np.logspace(1, 2, tick_rate)).astype(int) - 1 

530 for ii in range(lo_oom, hi_oom): 

531 tks = np.linspace(10**ii, 10**(ii+1), 100)[log_tick_pos_inds] 

532 if round_vals: 

533 ticks.append(np.round(tks / 10**ii)*10**ii) 

534 else: 

535 ticks.append(tks) 

536 #ticks.append(np.logspace(ii, ii+1, tick_rate)) 

537 

538 ticks = np.unique(np.r_[ticks]) 

539 inds = np.logical_and(ticks > lo, ticks < hi) 

540 return ticks[inds]