Coverage for emd/plotting.py: 5%
220 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-09 10:07 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-09 10:07 +0000
1#!/usr/bin/python
3# vim: set expandtab ts=4 sw=4:
5"""
6Routines for plotting results of EMD analyses.
8Main Routines:
9 plot_imfs
10 plot_hilberthuang
11 plot_holospectrum
Utilities:
    _get_sensible_ticks
    _get_log_tickpos
16"""
18import logging
19from functools import partial
21import numpy as np
# Module-level logger named after this module; used for user-facing
# warnings such as deprecated-argument notices.
logger = logging.getLogger(name=__name__)
def plot_imfs(imfs, time_vect=None, X=None, step=4, sample_rate=1, sharey=True,
              scale_y=None, cmap=True, fig=None, fig_args=None, ax=None,
              xlabel='Time (samples)', ylabel_args=None, ylabel_xoffset=-0.08,
              tick_params=None):
    """Create a quick summary plot for a set of IMFs.

    The raw signal (or summed IMFs) is drawn on the top row and each IMF on
    its own row beneath it, all within a single axis. Traces are
    z-transformed internally so that rows can be stacked at fixed vertical
    offsets of ``step``.

    Parameters
    ----------
    imfs : ndarray
        2D array of IMFs to plot, indexed as [samples, imfs]
    time_vect : ndarray
        Optional 1D array specifying time values (Default value = None)
    X : ndarray
        Original data prior to decomposition. If passed, this will be plotted
        on the top row rather than the sum of the IMFs. Useful for visualising
        incomplete sets of IMFs.
    step : float
        Scaling factor determining spacing between IMF subaxes, approximately
        corresponds to the z-value of the y-axis extremities for each IMF. If
        there is substantial overlap between IMFs, this value can be increased
        to compensate.
    sample_rate : float
        Optional sample rate used to build the time axis when time_vect is
        not specified.
    sharey: Boolean
        Flag indicating whether the y-axis should be adaptive to each mode
        (False) or consistent across modes (True) (Default value = True)
    scale_y : {None, bool}
        Deprecated, use 'sharey' instead. If passed, a warning is logged and
        scale_y=True behaves as sharey=False while any other value behaves as
        sharey=True.
    cmap : {None,True,matplotlib colormap}
        Optional colourmap to use. None will plot each IMF in black and True will
        use the plt.cm.Dark2 colormap as default. A different colormap may also
        be passed in.
    fig : matplotlib figure instance
        Optional figure to make the plot in. A new axis is added to this
        figure when 'ax' is not passed.
    fig_args : dict
        Dictionary of kwargs to pass to plt.figure, only used if neither 'fig'
        nor 'ax' is passed. Defaults to {'figsize': (16, 10)}.
    ax : matplotlib axes instance
        Optional axes to make the plot in.
    xlabel : str
        Optional x-axis label. Defaults to 'Time (samples)'
    ylabel_args : dict
        Optional arguments to be passed to ax.text to create the y-axis
        labels. Defaults to {'ha': 'center', 'va': 'center', 'fontsize': 14}.
        These values remain unless explicitly overridden. The passed dict is
        not modified.
    ylabel_xoffset : float
        Optional axis offset to fine-tune the position of the y axis labels.
        Defaults to -0.08 and typically only needs VERY minor adjustment.
    tick_params : dict
        Optional arguments passed to ax.tick_params to style the tick labels.
        Defaults to {'axis': 'both', 'which': 'major', 'labelsize': 10}.
        These values remain unless explicitly overridden. The passed dict is
        not modified.

    Returns
    -------
    ax
        Handle of plot axis

    """
    import matplotlib.pyplot as plt
    import matplotlib.transforms as transforms
    from matplotlib.colors import Colormap
    from scipy import stats

    if scale_y is not None:
        logger.warning("The argument 'scale_y' is depreciated and will be "
                       "removed in a future version. Please use 'sharey' to "
                       "remove this warning")
        sharey = False if scale_y is True else True
    # Only an explicit False switches to per-IMF scaling; this one flag drives
    # both the trace scaling and the tick computation so they always agree.
    per_imf_scale = sharey is False

    if time_vect is None:
        time_vect = np.linspace(0, imfs.shape[0]/sample_rate, imfs.shape[0])

    # Copy user-supplied option dicts so the defaults filled in below never
    # leak back into the caller's objects.
    ylabel_args = {} if ylabel_args is None else dict(ylabel_args)
    ylabel_args.setdefault('ha', 'center')
    ylabel_args.setdefault('va', 'center')
    ylabel_args.setdefault('fontsize', 14)

    tick_params = {} if tick_params is None else dict(tick_params)
    tick_params.setdefault('axis', 'both')
    tick_params.setdefault('which', 'major')
    tick_params.setdefault('labelsize', 10)

    top_label = 'Summed\nIMFs' if X is None else 'Raw\nSignal'
    X = imfs.sum(axis=1) if X is None else X
    x_std = X.std()

    # Number of decimals kept in tick labels: more for small-amplitude data.
    order_of_magnitude = int(np.floor(np.log(x_std)))
    round_scale = -order_of_magnitude if order_of_magnitude < 0 else 12

    # Everything is z-transformed internally to make positioning in the axis
    # simpler. We either z-transform relative to summed signal or for each imf
    # in turn. Also divide per-imf scaled data by 2 to reduce overlap as
    # z-transform relative to full signal will naturally give smaller ranges.
    #
    # Either way - Scale based on variance of raw data - don't touch the mean.
    if per_imf_scale:
        def scale_func(x):
            return (stats.zscore(x) + x.mean()) / 2
    else:
        scale_func = partial(stats.zmap, compare=X-X.mean())

    if fig is None and ax is None:
        if fig_args is None:
            fig_args = {'figsize': (16, 10)}
        fig = plt.figure(**fig_args)
        fig.subplots_adjust(top=0.975, right=0.975)

    if ax is None:
        # fig is always set here; draw into it rather than into whichever
        # figure pyplot currently considers active.
        ax = fig.add_subplot(1, 1, 1)
    ax.tick_params(**tick_params)

    if cmap is True:
        # Use default colormap
        cmap = plt.cm.Dark2
        cols = cmap(np.linspace(0, 1, imfs.shape[1] + 1))
    elif isinstance(cmap, Colormap):
        # Use specified colormap
        cols = cmap(np.linspace(0, 1, imfs.shape[1] + 1))
    else:
        # Use all black lines - this is overall default
        cols = np.array([[0, 0, 0] for ii in range(imfs.shape[1] + 1)])

    # Blended transform uses axis coords for X and data coords for Y
    trans = transforms.blended_transform_factory(ax.transAxes, ax.transData)

    yticks = []
    yticklabels = []

    # Plot full time-series on the top row (vertical offset 0)
    first_step = 0
    ax.plot((time_vect[0], time_vect[0]), ((first_step)-1.5, (first_step)+1.5), 'k')
    ax.plot(time_vect, np.zeros_like(time_vect) + first_step,
            lw=0.5, color=[0.8, 0.8, 0.8])
    ax.plot(time_vect, scale_func(X)+first_step, 'k')
    ax.text(ylabel_xoffset, first_step, top_label,
            transform=trans, **ylabel_args)

    # Ticks for the full time-series. The limit is not rounded: rounding to
    # 2 decimals collapsed it to 0 for small-amplitude signals, which breaks
    # the tick computation.
    yticks_imf = _get_sensible_ticks(1.5 * x_std)
    ytickpos = yticks_imf / x_std
    yticks.extend(first_step+ytickpos)
    yticklabels.extend(np.round(yticks_imf, round_scale))

    # Main IMF loop - each IMF sits `step` units below the previous row
    for ii in range(imfs.shape[1]):
        this_step = (ii+1)*step
        imf = imfs[:, ii]

        # Plot IMF and axis lines
        ax.plot(time_vect, np.zeros_like(time_vect) - this_step,
                lw=0.5, color=[0.8, 0.8, 0.8])
        ax.plot((time_vect[0], time_vect[0]), (-this_step-1.5, -this_step+1.5), 'k')
        ax.plot(time_vect, scale_func(imf) - this_step, color=cols[ii, :])

        # Tick values are in data units; positions are in the z-units of the
        # plotted trace, so divide by the std used for the scaling above.
        # NOTE(review): per-IMF traces are also halved by scale_func while the
        # tick positions are not - confirm whether labels should account for it.
        ref_std = imf.std() if per_imf_scale else x_std
        yticks_imf = _get_sensible_ticks(1.5 * ref_std)
        ytickpos = yticks_imf / ref_std

        yticks.extend(-this_step+ytickpos)
        yticklabels.extend(np.round(yticks_imf, round_scale))

        # Add label
        ax.text(ylabel_xoffset, -this_step, 'IMF-{}'.format(ii+1),
                transform=trans, **ylabel_args)

    # Hide unwanted spines
    for tag in ['left', 'top', 'right']:
        ax.spines[tag].set_visible(False)
    ymax = np.max(scale_func(X)+step/2)

    # Set axis limits
    ax.set_ylim(np.min(yticks)-1, ymax)
    ax.set_xlim(time_vect[0], time_vect[-1])

    # Set axis ticks
    ax.set_yticks(yticks)
    ax.set_yticklabels(yticklabels)

    ax.set_xlabel(xlabel, fontsize=ylabel_args.get('fontsize', 14))

    return ax
213def _get_sensible_ticks(lim, nbins=3):
214 """Return sensibly rounded tick positions based on a plotting range.
216 Based on code in matplotlib.ticker
217 Assuming symmetrical axes and 3 ticks for the moment
219 """
220 from matplotlib import ticker
221 scale, offset = ticker.scale_range(-lim, lim)
222 if lim/scale > 0.5:
223 scale = scale / 2
224 edge = ticker._Edge_integer(scale, offset)
225 low = edge.ge(-lim)
226 high = edge.le(lim)
228 ticks = np.linspace(low, high, nbins) * scale
230 return ticks
def plot_imfs_depreciated(imfs, time_vect=None, sample_rate=1, scale_y=False, freqs=None, cmap=None, fig=None):
    """Create a quick summary plot for a set of IMFs (legacy layout).

    The summed IMFs are drawn in a top subplot, followed by one subplot per
    IMF stacked vertically in the same figure.

    Parameters
    ----------
    imfs : ndarray
        2D array of IMFs to plot, indexed as [samples, imfs]
    time_vect : ndarray
        Optional 1D array specifying time values (Default value = None)
    sample_rate : float
        Optional sample rate used to build the time axis when time_vect is
        not specified.
    scale_y : Boolean
        Flag indicating whether the y-axis should be adaptive to each mode
        (False) or consistent across modes (True) (Default value = False)
    freqs : array_like
        Optional vector of frequencies for each IMF, shown as subplot titles
    cmap : {None,True,matplotlib colormap}
        Optional colourmap to use. None will plot each IMF in black and True will
        use the plt.cm.Dark2 colormap as default. A different colormap may also
        be passed in.
    fig : matplotlib figure instance
        Optional figure to draw into; a new figure is created otherwise.

    """
    import matplotlib.pyplot as plt
    from matplotlib.colors import Colormap

    n_imfs = imfs.shape[1]
    nplots = n_imfs + 1
    if time_vect is None:
        time_vect = np.linspace(0, imfs.shape[0]/sample_rate, imfs.shape[0])
    t_start, t_end = time_vect[0], time_vect[-1]

    summed = imfs.sum(axis=1)
    imf_peak = np.abs(imfs).max()
    sum_peak = np.abs(summed).max()

    if fig is None:
        fig = plt.figure()

    def _decorate(axis):
        # Strip frame lines and draw a grey zero line across the time range.
        for tag in ['top', 'right', 'bottom']:
            axis.spines[tag].set_visible(False)
        axis.plot((t_start, t_end), (0, 0), color=[.5, .5, .5])

    # Top row: summed IMFs in black
    ax = fig.add_subplot(nplots, 1, 1)
    if scale_y:
        ax.yaxis.get_major_locator().set_params(integer=True)
    _decorate(ax)
    ax.plot(time_vect, summed, 'k')
    ax.tick_params(axis='x', labelbottom=False)
    ax.set_xlim(t_start, t_end)
    ax.set_ylim(-sum_peak * 1.1, sum_peak * 1.1)
    ax.set_ylabel('Signal', rotation=0, labelpad=10)

    # One colour row per subplot; row 0 (the summed signal) is left unused.
    if cmap is True:
        cmap = plt.cm.Dark2
    if isinstance(cmap, Colormap):
        cols = cmap(np.linspace(0, 1, nplots))
    else:
        cols = np.zeros((nplots, 3), dtype=int)

    for idx in range(n_imfs):
        ax = fig.add_subplot(nplots, 1, idx + 2)
        _decorate(ax)
        ax.plot(time_vect, imfs[:, idx], color=cols[idx + 1, :])
        ax.set_xlim(t_start, t_end)
        if scale_y:
            ax.set_ylim(-imf_peak * 1.1, imf_peak * 1.1)
            ax.yaxis.get_major_locator().set_params(integer=True)
        ax.set_ylabel('IMF {0}'.format(idx + 1), rotation=0, labelpad=10)

        # Only the bottom subplot carries x tick labels and the axis label
        if idx == n_imfs - 1:
            ax.set_xlabel('Time')
        else:
            ax.tick_params(axis='x', labelbottom=False)
        if freqs is not None:
            ax.set_title(freqs[idx], fontsize=8)

    fig.subplots_adjust(top=.95, bottom=.1, left=.2, right=.99)
def plot_hilberthuang(hht, time_vect, freq_vect,
                      time_lims=None, freq_lims=None, log_y=False,
                      vmin=0, vmax=None,
                      fig=None, ax=None, cmap='hot_r'):
    """Create a quick summary plot for a Hilbert-Huang Transform.

    Parameters
    ----------
    hht : 2d array
        Hilbert-Huang spectrum to be plotted - output from emd.spectra.hilberthuang.
        Indexed as [frequency, time] to match freq_vect and time_vect.
    time_vect : vector
        Vector of time samples
    freq_vect : vector
        Vector of frequency bins
    time_lims : optional tuple or list (start_val, end_val)
        Optional time-limits (inclusive) to zoom in time on the x-axis
    freq_lims : optional tuple or list (start_val, end_val)
        Optional frequency-limits (inclusive) to zoom in frequency on the y-axis
    log_y : bool
        Flag indicating whether to set log-scale on the y-axis. Ticks start
        at the lowest positive displayed frequency when the lower frequency
        limit is not positive.
    vmin : float
        Lower limit of the colour scale (Default value = 0)
    vmax : float
        Upper limit of the colour scale. Defaults to the maximum of the
        displayed region.
    fig : optional figure handle
        Figure to plot inside
    ax : optional axis handle
        Axis to plot inside
    cmap : optional str or matplotlib.cm
        Colormap specification

    Returns
    -------
    ax
        Handle of plot axis

    """
    import matplotlib.pyplot as plt
    from matplotlib import ticker
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    time_vect = np.asarray(time_vect)
    freq_vect = np.asarray(freq_vect)

    # Make figure if no fig or axis are passed
    if (fig is None) and (ax is None):
        fig = plt.figure()

    # Create axis if no axis is passed.
    if ax is None:
        ax = fig.add_subplot(1, 1, 1)

    # Get time indices
    if time_lims is not None:
        tinds = np.logical_and(time_vect >= time_lims[0], time_vect <= time_lims[1])
    else:
        tinds = np.ones_like(time_vect).astype(bool)

    # Get frequency indices
    if freq_lims is not None:
        finds = np.logical_and(freq_vect >= freq_lims[0], freq_vect <= freq_lims[1])
    else:
        finds = np.ones_like(freq_vect).astype(bool)
        freq_lims = (freq_vect[0], freq_vect[-1])

    # Make space for colourbar
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.05)

    view = hht[np.ix_(finds, tinds)]
    if vmax is None:
        vmax = np.max(view)

    # Make main plot
    pcm = ax.pcolormesh(time_vect[tinds], freq_vect[finds], view,
                        vmin=vmin, vmax=vmax, cmap=cmap, shading='nearest')

    # Set labels
    ax.set_xlabel('Time')
    ax.set_ylabel('Frequency')
    ax.set_title('Hilbert-Huang Transform')

    # Scale axes if requested
    if log_y:
        ax.set_yscale('log')
        # log10 of a non-positive bound is undefined (e.g. a 0 Hz first bin),
        # so start ticks at the lowest positive frequency on display instead.
        tick_lo = freq_lims[0]
        if tick_lo <= 0:
            shown = freq_vect[finds]
            positive = shown[shown > 0]
            tick_lo = positive.min() if positive.size > 0 else None
        if tick_lo is not None:
            ax.set_yticks(_get_log_tickpos(tick_lo, freq_lims[1]))
        ax.get_yaxis().set_major_formatter(ticker.ScalarFormatter())

    # Add colourbar
    plt.colorbar(pcm, cax=cax, orientation='vertical')

    return ax
def plot_holospectrum(holo, freq_vect, am_freq_vect,
                      freq_lims=None, am_freq_lims=None,
                      log_x=False, log_y=False,
                      vmin=0, vmax=None,
                      fig=None, ax=None, cmap='hot_r', mask=True):
    """Create a quick summary plot for a Holospectrum.

    Parameters
    ----------
    holo : 2d array
        Holospectrum to be plotted - output from emd.spectra.holospectrum.
        Indexed as [am_frequency, frequency] to match am_freq_vect and
        freq_vect.
    freq_vect : vector
        Vector of frequency values for first-layer
    am_freq_vect : vector
        Vector of frequency values for amplitude modulations in second-layer
    freq_lims : optional tuple or list (start_val, end_val)
        Optional frequency-limits (inclusive) to zoom in frequency on the y-axis
    am_freq_lims : optional tuple or list (start_val, end_val)
        Optional frequency-limits (inclusive) to zoom in amplitude modulation
        frequency on the x-axis
    log_x : bool
        Flag indicating whether to set log-scale on x-axis
    log_y : bool
        Flag indicating whether to set log-scale on y-axis
    vmin : float
        Lower limit of the colour scale (Default value = 0)
    vmax : float
        Upper limit of the colour scale. Defaults to the largest non-masked
        value in the displayed region.
    fig : optional figure handle
        Figure to plot inside
    ax : optional axis handle
        Axis to plot inside
    cmap : optional str, matplotlib colormap or None
        Colormap specification; None uses matplotlib's default colormap. A
        copy is made before masked cells are coloured grey, so registered or
        caller-owned colormaps are not modified.
    mask : bool
        Flag indicating whether to grey out cells where the amplitude
        modulation frequency exceeds the carrier frequency (Default value = True)

    Returns
    -------
    ax
        Handle of plot axis

    """
    import copy

    import matplotlib.pyplot as plt
    from matplotlib import ticker
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    freq_vect = np.asarray(freq_vect)
    am_freq_vect = np.asarray(am_freq_vect)

    def _positive_lower(lo, shown):
        # Log-scale ticks need a positive lower bound; fall back to the
        # smallest positive value on display (None if there is none).
        if lo > 0:
            return lo
        positive = shown[shown > 0]
        return positive.min() if positive.size > 0 else None

    # Make figure if no fig or axis are passed
    if (fig is None) and (ax is None):
        fig = plt.figure()

    # Create axis if no axis is passed.
    if ax is None:
        ax = fig.add_subplot(1, 1, 1)

    # Get carrier frequency indices (defaults also feed the log-tick range)
    if freq_lims is not None:
        finds = np.logical_and(freq_vect >= freq_lims[0], freq_vect <= freq_lims[1])
    else:
        finds = np.ones_like(freq_vect).astype(bool)
        freq_lims = (freq_vect[0], freq_vect[-1])

    # Get amplitude modulation frequency indices
    if am_freq_lims is not None:
        am_finds = np.logical_and(am_freq_vect >= am_freq_lims[0], am_freq_vect <= am_freq_lims[1])
    else:
        am_finds = np.ones_like(am_freq_vect).astype(bool)
        am_freq_lims = (am_freq_vect[0], am_freq_vect[-1])

    # Float copy: NaN masking must work for integer input and must not touch holo
    plot_holo = np.array(holo, dtype=float)
    if mask:
        # Hide cells whose AM frequency is above the carrier frequency
        above_carrier = am_freq_vect[:, np.newaxis] > freq_vect[np.newaxis, :]
        plot_holo[:len(am_freq_vect), :len(freq_vect)][above_carrier] = np.nan

    # Set colourmap on a private copy so set_bad does not alter shared state
    cmap = copy.copy(plt.get_cmap(cmap))
    cmap.set_bad([0.8, 0.8, 0.8])

    # Make space for colourbar
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.05)

    view = plot_holo[np.ix_(am_finds, finds)]
    if vmax is None:
        # Masked cells are NaN; np.max would return NaN and break the colour scale
        vmax = np.nanmax(view)

    # Make main plot - transpose so carrier frequency runs along the y-axis
    pcm = ax.pcolormesh(am_freq_vect[am_finds], freq_vect[finds], view.T,
                        cmap=cmap, vmin=vmin, vmax=vmax, shading='nearest')

    # Set labels
    ax.set_xlabel('Amplitude Modulation Frequency')
    ax.set_ylabel('Carrier Wave Frequency')
    ax.set_title('Holospectrum')

    # Scale axes if requested
    if log_y:
        ax.set_yscale('log')
        tick_lo = _positive_lower(freq_lims[0], freq_vect[finds])
        if tick_lo is not None:
            ax.set_yticks(_get_log_tickpos(tick_lo, freq_lims[1]))
        ax.get_yaxis().set_major_formatter(ticker.ScalarFormatter())

    if log_x:
        ax.set_xscale('log')
        tick_lo = _positive_lower(am_freq_lims[0], am_freq_vect[am_finds])
        if tick_lo is not None:
            ax.set_xticks(_get_log_tickpos(tick_lo, am_freq_lims[1]))
        ax.get_xaxis().set_major_formatter(ticker.ScalarFormatter())

    # Add colourbar
    plt.colorbar(pcm, cax=cax, orientation='vertical')

    return ax
506def _get_log_tickpos(lo, hi, tick_rate=5, round_vals=True):
507 """Generate tick positions for log-scales.
509 Parameters
510 ----------
511 lo : float
512 Low end of frequency range
513 hi : float
514 High end of frequency range
515 tick_rate : int
516 Number of ticks per order-of-magnitude
517 round_vals : bool
518 Flag indicating whether ticks should be rounded to first non-zero value.
520 Returns
521 -------
522 ndarray
523 Vector of tick positions
525 """
526 lo_oom = np.floor(np.log10(lo)).astype(int)
527 hi_oom = np.ceil(np.log10(hi)).astype(int) + 1
528 ticks = []
529 log_tick_pos_inds = np.round(np.logspace(1, 2, tick_rate)).astype(int) - 1
530 for ii in range(lo_oom, hi_oom):
531 tks = np.linspace(10**ii, 10**(ii+1), 100)[log_tick_pos_inds]
532 if round_vals:
533 ticks.append(np.round(tks / 10**ii)*10**ii)
534 else:
535 ticks.append(tks)
536 #ticks.append(np.logspace(ii, ii+1, tick_rate))
538 ticks = np.unique(np.r_[ticks])
539 inds = np.logical_and(ticks > lo, ticks < hi)
540 return ticks[inds]