Coverage for /Users/Newville/Codes/xraylarch/larch/qtlib/plot2D.py: 0%

172 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-09 10:08 -0600

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3""" 

4Custom version of SILX Plot2D 

5============================= 

6""" 

7import time 

8import numpy as np 

9from silx.gui.plot import Plot2D as silxPlot2D 

10from larch.utils.logging import getLogger 

11 

12 

class Plot2D(silxPlot2D):
    """Custom Plot2D instance targeted to 2D images

    Extends the silx ``Plot2D`` widget with:

    - :meth:`addImage` accepting ``x``/``y`` axis arrays (converted to a silx
      origin/scale), a title, axis labels and a colormap range;
    - :meth:`addContours` overlaying iso-intensity contours computed with a
      marching-squares implementation from ``silx.image.marchingsquares``;
    - an optional index shown in the window title (:meth:`setIndex`).
    """

    def __init__(self, parent=None, backend=None, logger=None, title="Plot2D"):

        super().__init__(parent=parent, backend=backend)

        self._logger = logger or getLogger("Plot2D")
        self._index = None
        self._title = title
        self.setWindowTitle(self._title)
        self._image = None
        self._mask = None
        self._x = None
        self._y = None
        # marching-squares engine; created by addContours(), read by
        # _drawContours() (initialised here so it is always defined)
        self._ms = None
        self._origin = (0, 0)
        self._scale = (1, 1)
        self._xlabel = 'X'
        self._ylabel = 'Y'
        self.setKeepDataAspectRatio(True)
        self.getDefaultColormap().setName('viridis')

    def _updateWindowTitle(self):
        """Set the window title to '<index>: <title>', or '<title>' if no index"""
        if self._index is not None:
            self.setWindowTitle('{0}: {1}'.format(self._index, self._title))
        else:
            self.setWindowTitle(self._title)

    @staticmethod
    def _axisOriginScale(axis):
        """Return (origin, scale) of the pixel grid described by a 1D axis

        The origin is the first axis value and the scale is the first step,
        so a decreasing axis gives a negative scale anchored at axis[0].
        A single-point axis gets a unit scale.
        """
        axis = np.asarray(axis)
        origin = axis[0]
        scale = axis[1] - axis[0] if axis.size > 1 else 1
        return origin, scale

    def _drawContours(self, values, color='gray',
                      plot_timeout=100, plot_method='curve',
                      linestyle='-', linewidth=0.5):
        """Draw iso contours for given values

        Uses the marching-squares object stored in ``self._ms`` by
        :meth:`addContours`; does nothing if it is None.

        Parameters
        ----------
        values : list or array
            intensities at which to find contours
        color : string (optional)
            color of contours (among common color names) ['gray']
        plot_timeout : int (optional)
            time in seconds before the plot is interrupted [100]; both the
            contour search and the plotting are counted
        plot_method : str (optional)
            method to use for the contour plot
            'curve' -> self.addCurve, polygons as from find_contours
            'curve_max' -> self.addCurve, one polygon (max length)
            'curve_merge' -> self.addCurve, one polygon (concatenate)
            'scatter' -> self.addScatter (only points)
        linestyle : str (optional)
            line style of the contour curves ['-'] (ignored by 'scatter')
        linewidth : float (optional)
            line width of the contour curves [0.5] (ignored by 'scatter')
        """
        if self._ms is None:
            return
        if plot_method == 'scatter':
            from silx.gui.colors import Colormap, rgba
        # contour style does not depend on the level: build it once
        contourStyle = {"linestyle": linestyle,
                        "linewidth": linewidth,
                        "color": color}
        ipolygon = 0
        totTime = 0.
        for value in values:
            # stop everything (not just the current level) once timed out
            if totTime >= plot_timeout:
                self._logger.warning("Plot contours time out reached!")
                return
            startTime = time.time()
            polygons = self._ms.find_contours(value)
            totTime += time.time() - startTime
            self._logger.debug(f"Found {len(polygons)} polygon at level {value}")
            if len(polygons) == 0:
                continue
            # prepare polygons list for plot_method
            if len(polygons) > 1:
                if plot_method == 'curve_max':
                    lengths = [len(p) for p in polygons]
                    polygons = [polygons[int(np.argmax(lengths))]]
                elif plot_method in ('curve_merge', 'scatter'):
                    polygons = [np.concatenate(polygons, axis=0)]
            for polygon in polygons:
                if totTime >= plot_timeout:
                    self._logger.warning("Plot contours time out reached!")
                    return
                # polygon columns are taken as (row, col) pixel indices:
                # map them to the data coordinates used by addImage
                x = polygon[:, 1] * self._scale[0] + self._origin[0]
                y = polygon[:, 0] * self._scale[1] + self._origin[1]
                legend = "polygon-%d" % ipolygon
                startTime = time.time()
                if plot_method == 'scatter':
                    cm = Colormap()
                    cm.setColormapLUT([rgba(color)])
                    arrval = np.ones_like(x) * value
                    # explicit legend so each level is its own scatter item
                    self.addScatter(x, y, arrval, legend=legend,
                                    symbol='.', colormap=cm)
                else:
                    self.addCurve(x=x, y=y, legend=legend, resetzoom=False,
                                  **contourStyle)
                # only the time spent plotting this polygon is added
                totTime += time.time() - startTime
                ipolygon += 1

    def _initMarchingSquares(self, image, mask, algo):
        """Return a marching-squares object for `algo`, or None if unavailable"""
        if algo == 'merge':
            from silx.image.marchingsquares._mergeimpl import MarchingSquaresMergeImpl
            return MarchingSquaresMergeImpl(image, mask=mask)
        if algo == 'skimage':
            try:
                import skimage  # noqa: F401 -- availability check only
                from silx.image.marchingsquares._skimage import MarchingSquaresSciKitImage
            except ImportError:
                self._logger.error('skimage not found')
                return None
            return MarchingSquaresSciKitImage(image, mask=mask)
        self._logger.error(f"unknown marching squares algorithm: {algo}")
        return None

    def addContours(self, nlevels, algo='merge', **draw_kwars):
        """Add contour lines to plot

        The levels are `nlevels` evenly spaced intensities starting at the
        image minimum (the maximum itself is excluded).

        Parameters
        ----------
        nlevels : int
            number of contour levels to plot

        algo : str (optional)
            marching squares algorithm implementation
            'merge' -> silx
            'skimage' -> scikit-image
        color : str, optional
            color of contour lines ['gray']
        linestyle : str, optional
            line style of contour lines ['-']
        linewidth : float, optional
            line width of contour lines [0.5]
        plot_timeout, plot_method : optional
            forwarded to the contour drawing (see ``_drawContours``)

        Returns
        -------
        None
        """
        image = self._image
        mask = self._mask
        if image is None:
            self._logger.error('add image first!')
            return
        if nlevels < 1:
            self._logger.error(f"nlevels must be >= 1 (got {nlevels})")
            return
        self._ms = self._initMarchingSquares(image, mask, algo)
        if self._ms is None:
            return
        imgmin, imgmax = image.min(), image.max()
        if imgmax == imgmin:
            self._logger.warning('constant image: no contours to draw')
            return
        # same values as arange(imgmin, imgmax, (imgmax-imgmin)/nlevels), but
        # always exactly `nlevels` of them (no float round-off extra level)
        values = np.linspace(imgmin, imgmax, nlevels, endpoint=False)
        self._drawContours(values, **draw_kwars)

    def index(self):
        """Return the plot index, initialising it to 0 if not set yet"""
        if self._index is None:
            self._index = 0
        return self._index

    def setIndex(self, value):
        """Set the plot index (or None) and refresh the window title"""
        self._index = value
        self._updateWindowTitle()

    def reset(self):
        """Clear the plot and restore default graph title and axis labels

        The stored image and its origin/scale are forgotten too, so a new
        addImage() is needed before addContours().
        """
        self.clear()
        self._image = None
        self._ms = None
        self._origin = (0, 0)
        self._scale = (1, 1)
        self._xlabel = 'X'
        self._ylabel = 'Y'
        self.setGraphTitle()
        self.setGraphXLabel(self._xlabel)
        self.setGraphYLabel(self._ylabel)

    def addImage(self, data, x=None, y=None,
                 title=None, xlabel=None, ylabel=None,
                 vmin=None, vmax=None, **kwargs):
        """Custom addImage

        Parameters
        ----------
        data : array
        x, y : None or array (optional)
            x, y to set origin and scale (both should be given!); the
            origin is the first value and the scale the first step of each
            axis
        title : str
            set self.setGraphTitle(str) / self.setWindowTitle(str)
        xlabel, ylabel : None or str (optional)
            set self.setGraphXLabel / self.setGraphYLabel
        vmin, vmax : float (optional)
            intensity values of the colormap min/max [data min/max]
        **kwargs
            forwarded to silx ``addImage`` (must not contain origin/scale)

        Returns
        -------
        whatever the silx ``addImage`` returns
        """
        self._image = data
        self._x = x
        self._y = y
        if (x is not None) and (y is not None):
            xorigin, xscale = self._axisOriginScale(x)
            yorigin, yscale = self._axisOriginScale(y)
            self._origin = (xorigin, yorigin)
            self._scale = (xscale, yscale)
        if title is not None:
            self._title = title
            self.setGraphTitle(title)
            self._updateWindowTitle()
        if xlabel is not None:
            self._xlabel = xlabel
            self.setGraphXLabel(xlabel)
        if ylabel is not None:
            self._ylabel = ylabel
            self.setGraphYLabel(ylabel)
        if vmin is None:
            vmin = self._image.min()
        if vmax is None:
            vmax = self._image.max()
        self.getDefaultColormap().setVRange(vmin, vmax)
        return super().addImage(data, origin=self._origin,
                                scale=self._scale,
                                **kwargs)

214 

215 

def dummy_gauss_image(x=None, y=None,
                      xhalfrng=1.5, yhalfrng=None, xcen=0.5, ycen=0.9,
                      xnpts=1024, ynpts=None, xsigma=0.55, ysigma=0.25,
                      noise=0.3, seed=None):
    """Create a dummy 2D Gaussian image with noise

    Parameters
    ----------
    x, y : 1D arrays (optional)
        arrays where to generate the image [None -> generated]
    xhalfrng : float (optional)
        half range of the X axis [1.5]
    yhalfrng : float or None (optional)
        half range of the Y axis [None -> xhalfrng]
    xcen : float (optional)
        X center [0.5]
    ycen : float or None (optional)
        Y center [0.9; None -> xcen]
    xnpts : int (optional)
        number of points X [1024]
    ynpts : int or None (optional)
        number of points Y [None -> xnpts]
    xsigma : float (optional)
        sigma X [0.55]
    ysigma : float (optional)
        sigma Y [0.25]
    noise : float (optional)
        amplitude of the uniform random noise added, i.e. values are drawn
        in [0, noise) [0.3]
    seed : int or None (optional)
        seed for a dedicated noise generator, giving reproducible images
        [None -> use the global ``numpy.random`` state]

    Returns
    -------
    x, y : 1D arrays
    signal : 2D array, shape (len(y), len(x))
    """
    if yhalfrng is None:
        yhalfrng = xhalfrng
    if ycen is None:
        ycen = xcen
    if ynpts is None:
        ynpts = xnpts
    if x is None:
        x = np.linspace(xcen-xhalfrng, xcen+xhalfrng, xnpts)
    if y is None:
        y = np.linspace(ycen-yhalfrng, ycen+yhalfrng, ynpts)
    xx, yy = np.meshgrid(x, y)
    signal = np.exp(-((xx-xcen)**2 / (2*xsigma**2) +
                      ((yy-ycen)**2 / (2*ysigma**2))))
    # add noise (global RNG by default, for backward compatibility)
    if seed is None:
        rnd = np.random.random(size=signal.shape)
    else:
        rnd = np.random.default_rng(seed).random(size=signal.shape)
    signal += noise * rnd
    return x, y, signal

266 

267 

def main(contour_levels=5, noise=0.1, compare_with_matplolib=False,
         plot_method='curve'):
    """Run a Qt app with the widget

    Shows the sum of two noisy dummy Gaussian images in a :class:`Plot2D`
    with contour lines, and waits for <Enter> before returning.

    Parameters
    ----------
    contour_levels : int
        number of contour levels [5]
    noise : float
        noise amplitude of each dummy Gaussian image [0.1]
    compare_with_matplolib : bool
        if True, also draw the same image and contour levels with pure
        matplotlib for visual comparison [False]
    plot_method : str
        contour plot method passed to Plot2D.addContours ['curve']
    """
    from silx import sx
    sx.enable_gui()
    xhalfrng = 10.5
    yhalfrng = 5.5
    npts = 1024
    xcen = 0
    ycen = 0
    x = np.linspace(xcen-0.7*xhalfrng, xcen+1.3*xhalfrng, npts)
    y = np.linspace(ycen-0.7*yhalfrng, ycen+1.3*yhalfrng, npts)
    _, _, signal1 = dummy_gauss_image(x=x, y=y, xcen=xcen, ycen=ycen,
                                      xsigma=3, ysigma=1.1,
                                      noise=noise)
    _, _, signal2 = dummy_gauss_image(x=x, y=y,
                                      xcen=4.2, ycen=2.2,
                                      xsigma=3, ysigma=2.1,
                                      noise=noise)
    signal = signal1 + 0.8*signal2
    # keep a reference to the widget so it is not garbage collected
    p = Plot2D(backend='matplotlib')
    p.addImage(signal, x=x, y=y, xlabel='X', ylabel='Y')
    p.addContours(contour_levels, plot_method=plot_method)
    p.show()

    if compare_with_matplolib:
        import matplotlib.pyplot as plt
        plt.ion()
        plt.close('all')
        _, ax = plt.subplots()
        imgMin, imgMax = np.min(signal), np.max(signal)
        # same levels as Plot2D.addContours: `contour_levels` evenly spaced
        # values starting at the image minimum, maximum excluded
        values = np.linspace(imgMin, imgMax, contour_levels, endpoint=False)
        extent = (x.min(), x.max(), y.min(), y.max())
        ax.imshow(signal, origin='lower', extent=extent, cmap='viridis')
        # x, y are given, so contour needs no origin/extent
        ax.contour(x, y, signal, values, colors='gray', linewidths=1)
        ax.set_title("pure matplotlib")
        plt.show()

    input("Press enter to close window")

310 

311 

if __name__ == '__main__':
    # only run the interactive demo when executed as a script, not on import
    main()