Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import numpy as np 

2from scipy.linalg import eig 

3from scipy.special import comb 

4from scipy.signal import convolve 

5 

# Public API of this module: the Daubechies/QMF filter helpers and the
# wavelet / continuous-wavelet-transform functions defined below.
__all__ = ['daub', 'qmf', 'cascade', 'morlet', 'ricker', 'morlet2', 'cwt']

7 

8 

def daub(p):
    """
    The coefficients for the FIR low-pass filter producing Daubechies wavelets.

    p>=1 gives the order of the zero at f=1/2.
    There are 2p filter coefficients.

    Parameters
    ----------
    p : int
        Order of the zero at f=1/2, can have values from 1 to 34.

    Returns
    -------
    daub : ndarray
        The ``2*p`` low-pass filter coefficients, normalised so that they
        sum to ``sqrt(2)``.

    Raises
    ------
    ValueError
        If `p` is smaller than 1 or larger than 34.

    """
    sqrt = np.sqrt
    if p < 1:
        raise ValueError("p must be at least 1.")
    if p == 1:
        c = 1 / sqrt(2)
        return np.array([c, c])
    elif p == 2:
        f = sqrt(2) / 8
        c = sqrt(3)
        return f * np.array([1 + c, 3 + c, 3 - c, 1 - c])
    elif p == 3:
        # Closed form built from one complex root z1 and its conjugate.
        tmp = 12 * sqrt(10)
        z1 = 1.5 + sqrt(15 + tmp) / 6 - 1j * (sqrt(15) + sqrt(tmp - 15)) / 6
        z1c = np.conj(z1)
        f = sqrt(2) / 8
        d0 = np.real((1 - z1) * (1 - z1c))
        a0 = np.real(z1 * z1c)
        a1 = 2 * np.real(z1)
        return f / d0 * np.array([a0, 3 * a0 - a1, 3 * a0 - 3 * a1 + 1,
                                  a0 - 3 * a1 + 3, 3 - a1, 1])
    elif p < 35:
        # Construct P(y) = sum_k comb(p-1+k, k) * y**k (listed highest power
        # first, as np.roots expects) and factor it.
        P = [comb(p - 1 + k, k, exact=True) for k in range(p)][::-1]
        yj = np.roots(P)
        # Each root y yields a reciprocal pair of z roots solving
        # z + 1/z = 2 - 4*y; keep the one with |z| >= 1.
        # Build up final polynomial (1 + z)**p * prod(z - z_k).
        c = np.poly1d([1, 1])**p
        q = np.poly1d([1])
        for k in range(p - 1):
            yval = yj[k]
            part = 2 * sqrt(yval * (yval - 1))
            const = 1 - 2 * yval
            z1 = const + part
            if (abs(z1)) < 1:
                z1 = const - part
            q = q * [1, -z1]

        # Conjugate root pairs make the product real up to rounding noise.
        q = c * np.real(q)
        # Normalize result so the coefficients sum to sqrt(2)
        q = q / np.sum(q) * sqrt(2)
        return q.c[::-1]
    else:
        raise ValueError("Polynomial factorization does not work "
                         "well for p too large.")

33 elif p == 2: 

34 f = sqrt(2) / 8 

35 c = sqrt(3) 

36 return f * np.array([1 + c, 3 + c, 3 - c, 1 - c]) 

37 elif p == 3: 

38 tmp = 12 * sqrt(10) 

39 z1 = 1.5 + sqrt(15 + tmp) / 6 - 1j * (sqrt(15) + sqrt(tmp - 15)) / 6 

40 z1c = np.conj(z1) 

41 f = sqrt(2) / 8 

42 d0 = np.real((1 - z1) * (1 - z1c)) 

43 a0 = np.real(z1 * z1c) 

44 a1 = 2 * np.real(z1) 

45 return f / d0 * np.array([a0, 3 * a0 - a1, 3 * a0 - 3 * a1 + 1, 

46 a0 - 3 * a1 + 3, 3 - a1, 1]) 

47 elif p < 35: 

48 # construct polynomial and factor it 

49 if p < 35: 

50 P = [comb(p - 1 + k, k, exact=1) for k in range(p)][::-1] 

51 yj = np.roots(P) 

52 else: # try different polynomial --- needs work 

53 P = [comb(p - 1 + k, k, exact=1) / 4.0**k 

54 for k in range(p)][::-1] 

55 yj = np.roots(P) / 4 

56 # for each root, compute two z roots, select the one with |z|>1 

57 # Build up final polynomial 

58 c = np.poly1d([1, 1])**p 

59 q = np.poly1d([1]) 

60 for k in range(p - 1): 

61 yval = yj[k] 

62 part = 2 * sqrt(yval * (yval - 1)) 

63 const = 1 - 2 * yval 

64 z1 = const + part 

65 if (abs(z1)) < 1: 

66 z1 = const - part 

67 q = q * [1, -z1] 

68 

69 q = c * np.real(q) 

70 # Normalize result 

71 q = q / np.sum(q) * sqrt(2) 

72 return q.c[::-1] 

73 else: 

74 raise ValueError("Polynomial factorization does not work " 

75 "well for p too large.") 

76 

77 

78def qmf(hk): 

79 """ 

80 Return high-pass qmf filter from low-pass 

81 

82 Parameters 

83 ---------- 

84 hk : array_like 

85 Coefficients of high-pass filter. 

86 

87 """ 

88 N = len(hk) - 1 

89 asgn = [{0: 1, 1: -1}[k % 2] for k in range(N + 1)] 

90 return hk[::-1] * np.array(asgn) 

91 

92 

93def cascade(hk, J=7): 

94 """ 

95 Return (x, phi, psi) at dyadic points ``K/2**J`` from filter coefficients. 

96 

97 Parameters 

98 ---------- 

99 hk : array_like 

100 Coefficients of low-pass filter. 

101 J : int, optional 

102 Values will be computed at grid points ``K/2**J``. Default is 7. 

103 

104 Returns 

105 ------- 

106 x : ndarray 

107 The dyadic points ``K/2**J`` for ``K=0...N * (2**J)-1`` where 

108 ``len(hk) = len(gk) = N+1``. 

109 phi : ndarray 

110 The scaling function ``phi(x)`` at `x`: 

111 ``phi(x) = sum(hk * phi(2x-k))``, where k is from 0 to N. 

112 psi : ndarray, optional 

113 The wavelet function ``psi(x)`` at `x`: 

114 ``phi(x) = sum(gk * phi(2x-k))``, where k is from 0 to N. 

115 `psi` is only returned if `gk` is not None. 

116 

117 Notes 

118 ----- 

119 The algorithm uses the vector cascade algorithm described by Strang and 

120 Nguyen in "Wavelets and Filter Banks". It builds a dictionary of values 

121 and slices for quick reuse. Then inserts vectors into final vector at the 

122 end. 

123 

124 """ 

125 N = len(hk) - 1 

126 

127 if (J > 30 - np.log2(N + 1)): 

128 raise ValueError("Too many levels.") 

129 if (J < 1): 

130 raise ValueError("Too few levels.") 

131 

132 # construct matrices needed 

133 nn, kk = np.ogrid[:N, :N] 

134 s2 = np.sqrt(2) 

135 # append a zero so that take works 

136 thk = np.r_[hk, 0] 

137 gk = qmf(hk) 

138 tgk = np.r_[gk, 0] 

139 

140 indx1 = np.clip(2 * nn - kk, -1, N + 1) 

141 indx2 = np.clip(2 * nn - kk + 1, -1, N + 1) 

142 m = np.zeros((2, 2, N, N), 'd') 

143 m[0, 0] = np.take(thk, indx1, 0) 

144 m[0, 1] = np.take(thk, indx2, 0) 

145 m[1, 0] = np.take(tgk, indx1, 0) 

146 m[1, 1] = np.take(tgk, indx2, 0) 

147 m *= s2 

148 

149 # construct the grid of points 

150 x = np.arange(0, N * (1 << J), dtype=float) / (1 << J) 

151 phi = 0 * x 

152 

153 psi = 0 * x 

154 

155 # find phi0, and phi1 

156 lam, v = eig(m[0, 0]) 

157 ind = np.argmin(np.absolute(lam - 1)) 

158 # a dictionary with a binary representation of the 

159 # evaluation points x < 1 -- i.e. position is 0.xxxx 

160 v = np.real(v[:, ind]) 

161 # need scaling function to integrate to 1 so find 

162 # eigenvector normalized to sum(v,axis=0)=1 

163 sm = np.sum(v) 

164 if sm < 0: # need scaling function to integrate to 1 

165 v = -v 

166 sm = -sm 

167 bitdic = {'0': v / sm} 

168 bitdic['1'] = np.dot(m[0, 1], bitdic['0']) 

169 step = 1 << J 

170 phi[::step] = bitdic['0'] 

171 phi[(1 << (J - 1))::step] = bitdic['1'] 

172 psi[::step] = np.dot(m[1, 0], bitdic['0']) 

173 psi[(1 << (J - 1))::step] = np.dot(m[1, 1], bitdic['0']) 

174 # descend down the levels inserting more and more values 

175 # into bitdic -- store the values in the correct location once we 

176 # have computed them -- stored in the dictionary 

177 # for quicker use later. 

178 prevkeys = ['1'] 

179 for level in range(2, J + 1): 

180 newkeys = ['%d%s' % (xx, yy) for xx in [0, 1] for yy in prevkeys] 

181 fac = 1 << (J - level) 

182 for key in newkeys: 

183 # convert key to number 

184 num = 0 

185 for pos in range(level): 

186 if key[pos] == '1': 

187 num += (1 << (level - 1 - pos)) 

188 pastphi = bitdic[key[1:]] 

189 ii = int(key[0]) 

190 temp = np.dot(m[0, ii], pastphi) 

191 bitdic[key] = temp 

192 phi[num * fac::step] = temp 

193 psi[num * fac::step] = np.dot(m[1, ii], pastphi) 

194 prevkeys = newkeys 

195 

196 return x, phi, psi 

197 

198 

199def morlet(M, w=5.0, s=1.0, complete=True): 

200 """ 

201 Complex Morlet wavelet. 

202 

203 Parameters 

204 ---------- 

205 M : int 

206 Length of the wavelet. 

207 w : float, optional 

208 Omega0. Default is 5 

209 s : float, optional 

210 Scaling factor, windowed from ``-s*2*pi`` to ``+s*2*pi``. Default is 1. 

211 complete : bool, optional 

212 Whether to use the complete or the standard version. 

213 

214 Returns 

215 ------- 

216 morlet : (M,) ndarray 

217 

218 See Also 

219 -------- 

220 morlet2 : Implementation of Morlet wavelet, compatible with `cwt`. 

221 scipy.signal.gausspulse 

222 

223 Notes 

224 ----- 

225 The standard version:: 

226 

227 pi**-0.25 * exp(1j*w*x) * exp(-0.5*(x**2)) 

228 

229 This commonly used wavelet is often referred to simply as the 

230 Morlet wavelet. Note that this simplified version can cause 

231 admissibility problems at low values of `w`. 

232 

233 The complete version:: 

234 

235 pi**-0.25 * (exp(1j*w*x) - exp(-0.5*(w**2))) * exp(-0.5*(x**2)) 

236 

237 This version has a correction 

238 term to improve admissibility. For `w` greater than 5, the 

239 correction term is negligible. 

240 

241 Note that the energy of the return wavelet is not normalised 

242 according to `s`. 

243 

244 The fundamental frequency of this wavelet in Hz is given 

245 by ``f = 2*s*w*r / M`` where `r` is the sampling rate. 

246 

247 Note: This function was created before `cwt` and is not compatible 

248 with it. 

249 

250 """ 

251 x = np.linspace(-s * 2 * np.pi, s * 2 * np.pi, M) 

252 output = np.exp(1j * w * x) 

253 

254 if complete: 

255 output -= np.exp(-0.5 * (w**2)) 

256 

257 output *= np.exp(-0.5 * (x**2)) * np.pi**(-0.25) 

258 

259 return output 

260 

261 

262def ricker(points, a): 

263 """ 

264 Return a Ricker wavelet, also known as the "Mexican hat wavelet". 

265 

266 It models the function: 

267 

268 ``A * (1 - (x/a)**2) * exp(-0.5*(x/a)**2)``, 

269 

270 where ``A = 2/(sqrt(3*a)*(pi**0.25))``. 

271 

272 Parameters 

273 ---------- 

274 points : int 

275 Number of points in `vector`. 

276 Will be centered around 0. 

277 a : scalar 

278 Width parameter of the wavelet. 

279 

280 Returns 

281 ------- 

282 vector : (N,) ndarray 

283 Array of length `points` in shape of ricker curve. 

284 

285 Examples 

286 -------- 

287 >>> from scipy import signal 

288 >>> import matplotlib.pyplot as plt 

289 

290 >>> points = 100 

291 >>> a = 4.0 

292 >>> vec2 = signal.ricker(points, a) 

293 >>> print(len(vec2)) 

294 100 

295 >>> plt.plot(vec2) 

296 >>> plt.show() 

297 

298 """ 

299 A = 2 / (np.sqrt(3 * a) * (np.pi**0.25)) 

300 wsq = a**2 

301 vec = np.arange(0, points) - (points - 1.0) / 2 

302 xsq = vec**2 

303 mod = (1 - xsq / wsq) 

304 gauss = np.exp(-xsq / (2 * wsq)) 

305 total = A * mod * gauss 

306 return total 

307 

308 

309def morlet2(M, s, w=5): 

310 """ 

311 Complex Morlet wavelet, designed to work with `cwt`. 

312 

313 Returns the complete version of morlet wavelet, normalised 

314 according to `s`:: 

315 

316 exp(1j*w*x/s) * exp(-0.5*(x/s)**2) * pi**(-0.25) * sqrt(1/s) 

317 

318 Parameters 

319 ---------- 

320 M : int 

321 Length of the wavelet. 

322 s : float 

323 Width parameter of the wavelet. 

324 w : float, optional 

325 Omega0. Default is 5 

326 

327 Returns 

328 ------- 

329 morlet : (M,) ndarray 

330 

331 See Also 

332 -------- 

333 morlet : Implementation of Morlet wavelet, incompatible with `cwt` 

334 

335 Notes 

336 ----- 

337 

338 .. versionadded:: 1.4.0 

339 

340 This function was designed to work with `cwt`. Because `morlet2` 

341 returns an array of complex numbers, the `dtype` argument of `cwt` 

342 should be set to `complex128` for best results. 

343 

344 Note the difference in implementation with `morlet`. 

345 The fundamental frequency of this wavelet in Hz is given by:: 

346 

347 f = w*fs / (2*s*np.pi) 

348 

349 where ``fs`` is the sampling rate and `s` is the wavelet width parameter. 

350 Similarly we can get the wavelet width parameter at ``f``:: 

351 

352 s = w*fs / (2*f*np.pi) 

353 

354 Examples 

355 -------- 

356 >>> from scipy import signal 

357 >>> import matplotlib.pyplot as plt 

358 

359 >>> M = 100 

360 >>> s = 4.0 

361 >>> w = 2.0 

362 >>> wavelet = signal.morlet2(M, s, w) 

363 >>> plt.plot(abs(wavelet)) 

364 >>> plt.show() 

365 

366 This example shows basic use of `morlet2` with `cwt` in time-frequency 

367 analysis: 

368 

369 >>> from scipy import signal 

370 >>> import matplotlib.pyplot as plt 

371 >>> t, dt = np.linspace(0, 1, 200, retstep=True) 

372 >>> fs = 1/dt 

373 >>> w = 6. 

374 >>> sig = np.cos(2*np.pi*(50 + 10*t)*t) + np.sin(40*np.pi*t) 

375 >>> freq = np.linspace(1, fs/2, 100) 

376 >>> widths = w*fs / (2*freq*np.pi) 

377 >>> cwtm = signal.cwt(sig, signal.morlet2, widths, w=w) 

378 >>> plt.pcolormesh(t, freq, np.abs(cwtm), cmap='viridis', shading='gouraud') 

379 >>> plt.show() 

380 

381 """ 

382 x = np.arange(0, M) - (M - 1.0) / 2 

383 x = x / s 

384 wavelet = np.exp(1j * w * x) * np.exp(-0.5 * x**2) * np.pi**(-0.25) 

385 output = np.sqrt(1/s) * wavelet 

386 return output 

387 

388 

389def cwt(data, wavelet, widths, dtype=None, **kwargs): 

390 """ 

391 Continuous wavelet transform. 

392 

393 Performs a continuous wavelet transform on `data`, 

394 using the `wavelet` function. A CWT performs a convolution 

395 with `data` using the `wavelet` function, which is characterized 

396 by a width parameter and length parameter. The `wavelet` function 

397 is allowed to be complex. 

398 

399 Parameters 

400 ---------- 

401 data : (N,) ndarray 

402 data on which to perform the transform. 

403 wavelet : function 

404 Wavelet function, which should take 2 arguments. 

405 The first argument is the number of points that the returned vector 

406 will have (len(wavelet(length,width)) == length). 

407 The second is a width parameter, defining the size of the wavelet 

408 (e.g. standard deviation of a gaussian). See `ricker`, which 

409 satisfies these requirements. 

410 widths : (M,) sequence 

411 Widths to use for transform. 

412 dtype : data-type, optional 

413 The desired data type of output. Defaults to ``float64`` if the 

414 output of `wavelet` is real and ``complex128`` if it is complex. 

415 

416 .. versionadded:: 1.4.0 

417 

418 kwargs 

419 Keyword arguments passed to wavelet function. 

420 

421 .. versionadded:: 1.4.0 

422 

423 Returns 

424 ------- 

425 cwt: (M, N) ndarray 

426 Will have shape of (len(widths), len(data)). 

427 

428 Notes 

429 ----- 

430 

431 .. versionadded:: 1.4.0 

432 

433 For non-symmetric, complex-valued wavelets, the input signal is convolved 

434 with the time-reversed complex-conjugate of the wavelet data [1]. 

435 

436 :: 

437 

438 length = min(10 * width[ii], len(data)) 

439 cwt[ii,:] = signal.convolve(data, np.conj(wavelet(length, width[ii], 

440 **kwargs))[::-1], mode='same') 

441 

442 References 

443 ---------- 

444 .. [1] S. Mallat, "A Wavelet Tour of Signal Processing (3rd Edition)", 

445 Academic Press, 2009. 

446 

447 Examples 

448 -------- 

449 >>> from scipy import signal 

450 >>> import matplotlib.pyplot as plt 

451 >>> t = np.linspace(-1, 1, 200, endpoint=False) 

452 >>> sig = np.cos(2 * np.pi * 7 * t) + signal.gausspulse(t - 0.4, fc=2) 

453 >>> widths = np.arange(1, 31) 

454 >>> cwtmatr = signal.cwt(sig, signal.ricker, widths) 

455 >>> plt.imshow(cwtmatr, extent=[-1, 1, 1, 31], cmap='PRGn', aspect='auto', 

456 ... vmax=abs(cwtmatr).max(), vmin=-abs(cwtmatr).max()) 

457 >>> plt.show() 

458 """ 

459 if wavelet == ricker: 

460 window_size = kwargs.pop('window_size', None) 

461 # Determine output type 

462 if dtype is None: 

463 if np.asarray(wavelet(1, widths[0], **kwargs)).dtype.char in 'FDG': 

464 dtype = np.complex128 

465 else: 

466 dtype = np.float64 

467 

468 output = np.zeros((len(widths), len(data)), dtype=dtype) 

469 for ind, width in enumerate(widths): 

470 N = np.min([10 * width, len(data)]) 

471 # the conditional block below and the window_size 

472 # kwarg pop above may be removed eventually; these 

473 # are shims for 32-bit arch + NumPy <= 1.14.5 to 

474 # address gh-11095 

475 if wavelet == ricker and window_size is None: 

476 ceil = np.ceil(N) 

477 if ceil != N: 

478 N = int(N) 

479 wavelet_data = np.conj(wavelet(N, width, **kwargs)[::-1]) 

480 output[ind] = convolve(data, wavelet_data, mode='same') 

481 return output