Coverage for /Users/Newville/Codes/xraylarch/larch/qtlib/plot2D.py: 0%
172 statements
Report generated by coverage.py v7.3.2, created at 2023-11-09 10:08 -0600
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Custom version of SILX Plot2D
5=============================
6"""
7import time
8import numpy as np
9from silx.gui.plot import Plot2D as silxPlot2D
10from larch.utils.logging import getLogger
class Plot2D(silxPlot2D):
    """Custom Plot2D instance targeted to 2D images

    On top of the silx ``Plot2D`` widget this class remembers the last image
    given to :meth:`addImage`, together with the origin/scale derived from
    its optional x/y axes. Iso-contours can then be overlaid on that image
    with :meth:`addContours`. The default colormap is set to 'viridis' and
    the data aspect ratio is kept.
    """

    def __init__(self, parent=None, backend=None, logger=None, title="Plot2D"):
        super(Plot2D, self).__init__(parent=parent, backend=backend)
        self._logger = logger or getLogger("Plot2D")
        self._index = None
        self._title = title
        self.setWindowTitle(self._title)
        self._image = None
        self._mask = None
        # optional axes given to addImage (kept for reference)
        self._x = None
        self._y = None
        # marching squares instance created by addContours(); initialised
        # here so that _drawContours() is safe to call before addContours()
        self._ms = None
        self._origin = (0, 0)
        self._scale = (1, 1)
        self._xlabel = 'X'
        self._ylabel = 'Y'
        self.setKeepDataAspectRatio(True)
        self.getDefaultColormap().setName('viridis')

    @staticmethod
    def _asPair(value):
        """Return ``value`` as an (x, y) tuple, duplicating a scalar."""
        if np.isscalar(value):
            return (value, value)
        xval, yval = value
        return (xval, yval)

    @staticmethod
    def _axisStep(axis):
        """Return the spacing of the first two points of a 1D axis.

        Returns 1 when the axis has fewer than two points, instead of
        raising IndexError.
        """
        axis = np.asarray(axis)
        if axis.size < 2:
            return 1
        return axis[1] - axis[0]

    @staticmethod
    def _contourLevels(image, nlevels):
        """Return ``nlevels`` equally spaced intensities in [min, max).

        NaNs in ``image`` are ignored. An empty array is returned when
        ``nlevels < 1`` or the image is flat, where the original
        ``np.arange`` call would divide by zero. ``linspace`` with
        ``endpoint=False`` gives the same values as
        ``arange(min, max, (max - min) / nlevels)``, but always exactly
        ``nlevels`` of them (float ``arange`` may return one extra).
        """
        imgmin, imgmax = np.nanmin(image), np.nanmax(image)
        if nlevels < 1 or not imgmax > imgmin:
            return np.array([])
        return np.linspace(imgmin, imgmax, int(nlevels), endpoint=False)

    def _updateWindowTitle(self):
        """Set the window title to 'index: title', or just the title when no
        index is set."""
        if self._index is not None:
            self.setWindowTitle('{0}: {1}'.format(self._index, self._title))
        else:
            self.setWindowTitle(self._title)

    def _drawContours(self, values, color='gray',
                      plot_timeout=100, plot_method='curve',
                      linestyle='-', linewidth=0.5):
        """Draw iso contours for given values

        Contours are computed with the marching squares instance stored in
        ``self._ms`` (see :meth:`addContours`); nothing is drawn when it is
        not set. Polygon coordinates (row, col) are mapped to plot
        coordinates with the current origin and scale.

        Parameters
        ----------
        values : list or array
            intensities at which to find contours
        color : string (optional)
            color of contours (among common color names) ['gray']
        plot_timeout : int (optional)
            total time in seconds (contour finding + plotting) before
            the plot is interrupted [100]
        plot_method : str (optional)
            method to use for the contour plot
            'curve' -> self.addCurve, polygons as from find_contours
            'curve_max' -> self.addCurve, one polygon (max length)
            'curve_merge' -> self.addCurve, one polygon (concatenate)
            'scatter' -> self.addScatter (only points)
        linestyle : str (optional)
            line style of contour curves ['-']
        linewidth : float (optional)
            line width of contour curves [0.5]
        """
        if self._ms is None:
            return
        # loop-invariant setup, hoisted out of the per-polygon loop
        scatterColormap = None
        if plot_method == 'scatter':
            from silx.gui.colors import Colormap, rgba
            scatterColormap = Colormap()
            scatterColormap.setColormapLUT([rgba(color)])
        contourStyle = {"linestyle": linestyle,
                        "linewidth": linewidth,
                        "color": color}
        xscale, yscale = self._scale
        xorigin, yorigin = self._origin
        startTime = time.time()
        ipolygon = 0
        for value in values:
            polygons = self._ms.find_contours(value)
            self._logger.debug(f"Found {len(polygons)} polygon at level {value}")
            if len(polygons) == 0:
                continue
            # prepare polygons list for plot_method
            if len(polygons) > 1:
                if plot_method == 'curve_max':
                    # first polygon with the largest number of points
                    polygons = [max(polygons, key=len)]
                elif plot_method in ('curve_merge', 'scatter'):
                    polygons = [np.concatenate(polygons, axis=0)]
            for polygon in polygons:
                # total elapsed time: stop the whole drawing (not only the
                # current level) once the timeout is reached
                if time.time() - startTime >= plot_timeout:
                    self._logger.warning("Plot contours time out reached!")
                    return
                legend = "polygon-%d" % ipolygon
                # find_contours returns (row, col) = (y, x) pixel positions
                x = polygon[:, 1] * xscale + xorigin
                y = polygon[:, 0] * yscale + yorigin
                if plot_method == 'scatter':
                    arrval = np.ones_like(x) * value
                    # a distinct legend per level, otherwise each scatter
                    # would replace the previous one
                    self.addScatter(x, y, arrval, legend=legend, symbol='.',
                                    colormap=scatterColormap)
                else:
                    self.addCurve(x=x, y=y, legend=legend, resetzoom=False,
                                  **contourStyle)
                ipolygon += 1

    def addContours(self, nlevels, algo='merge', **draw_kwargs):
        """Add contour lines to plot

        Contours are drawn on the image last given to :meth:`addImage`, at
        ``nlevels`` equally spaced intensities from the image minimum
        (included) to the image maximum (excluded). NaNs are ignored.

        Parameters
        ----------
        nlevels : int
            number of contour levels to plot
        algo : str (optional)
            marching squares algorithm implementation
            'merge' -> silx
            'skimage' -> scikit-image
        **draw_kwargs :
            passed to the contour drawing method:
            color : str, optional
                color of contour lines ['gray']
            linestyle : str, optional
                line style of contour lines ['-']
            linewidth : float, optional
                line width of contour lines [0.5]
            plot_timeout : float, optional
                time in seconds before plotting is interrupted [100]
            plot_method : str, optional
                'curve', 'curve_max', 'curve_merge' or 'scatter' ['curve']

        Returns
        -------
        None
        """
        image = self._image
        if image is None:
            self._logger.error('add image first!')
            return
        self._ms = None
        if algo == 'merge':
            from silx.image.marchingsquares._mergeimpl import MarchingSquaresMergeImpl
            self._ms = MarchingSquaresMergeImpl(image, mask=self._mask)
        elif algo == 'skimage':
            try:
                import skimage  # noqa: F401 -- only checks availability
                from silx.image.marchingsquares._skimage import MarchingSquaresSciKitImage
                self._ms = MarchingSquaresSciKitImage(image, mask=self._mask)
            except ImportError:
                self._logger.error('skimage not found')
                self._ms = None
        else:
            self._logger.error(f"unknown contour algorithm: '{algo}'")
        if self._ms is None:
            return
        values = self._contourLevels(image, nlevels)
        if len(values) == 0:
            self._logger.warning('no contour levels to plot '
                                 '(nlevels < 1 or flat image)')
            return
        self._drawContours(values, **draw_kwargs)

    def index(self):
        """Return the plot index, initialising it to 0 if never set."""
        if self._index is None:
            self._index = 0
        return self._index

    def setIndex(self, value):
        """Set the plot index and refresh the window title."""
        self._index = value
        self._updateWindowTitle()

    def reset(self):
        """Clear the plot and forget the stored image/contours.

        Graph title, axis labels, origin and scale go back to their
        initial defaults.
        """
        self.clear()
        self.setGraphTitle()
        self._image = None
        self._x = None
        self._y = None
        self._ms = None
        self._origin = (0, 0)
        self._scale = (1, 1)
        self._xlabel = 'X'
        self._ylabel = 'Y'
        self.setGraphXLabel(self._xlabel)
        self.setGraphYLabel(self._ylabel)

    def addImage(self, data, x=None, y=None,
                 title=None, xlabel=None, ylabel=None,
                 vmin=None, vmax=None, **kwargs):
        """Custom addImage

        Parameters
        ----------
        data : array
            image to display (also stored for :meth:`addContours`)
        x, y : None or array (optional)
            1D axes used to set origin (their minimum) and scale (spacing of
            their first two points). Both should be given; a single one is
            ignored with a warning.
        title : str
            set self.setGraphTitle(str) / self.setWindowTitle(str)
        xlabel, ylabel : None or str (optional)
            set self.setGraphXLabel / self.setGraphYLabel
        vmin, vmax : float (optional)
            intensity values of the colormap min/max
            [None -> image min/max, NaNs ignored]
        **kwargs :
            passed to the silx ``addImage``. ``origin`` and ``scale`` are
            accepted and used only when x and y are not both given. When
            neither x/y nor origin/scale is given, the previous origin/scale
            is kept.

        Returns
        -------
        the value returned by the silx ``Plot2D.addImage``
        """
        # pop these so they are not passed twice to the silx addImage
        origin = kwargs.pop('origin', None)
        scale = kwargs.pop('scale', None)
        self._image = data
        self._x = x
        self._y = y
        if (x is not None) and (y is not None):
            self._origin = (np.min(x), np.min(y))
            self._scale = (self._axisStep(x), self._axisStep(y))
        else:
            if (x is not None) or (y is not None):
                self._logger.warning('both x and y are needed to set '
                                     'origin and scale: ignored')
            if origin is not None:
                self._origin = self._asPair(origin)
            if scale is not None:
                self._scale = self._asPair(scale)
        if title is not None:
            self._title = title
            self.setGraphTitle(title)
            self._updateWindowTitle()
        if xlabel is not None:
            self._xlabel = xlabel
            self.setGraphXLabel(xlabel)
        if ylabel is not None:
            self._ylabel = ylabel
            self.setGraphYLabel(ylabel)
        # nan-aware so that a few NaN pixels do not break the colormap range
        if vmin is None:
            vmin = np.nanmin(data)
        if vmax is None:
            vmax = np.nanmax(data)
        self.getDefaultColormap().setVRange(vmin, vmax)
        return super(Plot2D, self).addImage(data, origin=self._origin,
                                            scale=self._scale,
                                            **kwargs)
def dummy_gauss_image(x=None, y=None,
                      xhalfrng=1.5, yhalfrng=None, xcen=0.5, ycen=0.9,
                      xnpts=1024, ynpts=None, xsigma=0.55, ysigma=0.25,
                      noise=0.3):
    """Create a dummy 2D Gaussian image with noise

    Parameters
    ----------
    x, y : 1D arrays (optional)
        axes on which the image is evaluated [None -> built with linspace
        around the center]
    xhalfrng : float (optional)
        half range of the generated X axis [1.5]
    yhalfrng : float or None (optional)
        half range of the generated Y axis [None -> xhalfrng]
    xcen : float (optional)
        X center of the Gaussian [0.5]
    ycen : float or None (optional)
        Y center of the Gaussian [0.9; None -> xcen]
    xnpts : int (optional)
        number of points of the generated X axis [1024]
    ynpts : int or None (optional)
        number of points of the generated Y axis [None -> xnpts]
    xsigma : float (optional)
        Gaussian sigma along X [0.55]
    ysigma : float (optional)
        Gaussian sigma along Y [0.25]
    noise : float (optional)
        amplitude of the added uniform random noise, drawn in
        [0, noise) with ``np.random.random`` [0.3]

    Returns
    -------
    x, y : 1D arrays
    signal : 2D array of shape (len(y), len(x))
    """
    # Y settings fall back to their X counterparts when omitted
    yhalfrng = xhalfrng if yhalfrng is None else yhalfrng
    ycen = xcen if ycen is None else ycen
    ynpts = xnpts if ynpts is None else ynpts
    if x is None:
        x = np.linspace(xcen - xhalfrng, xcen + xhalfrng, xnpts)
    if y is None:
        y = np.linspace(ycen - yhalfrng, ycen + yhalfrng, ynpts)
    grid_x, grid_y = np.meshgrid(x, y)
    exponent = ((grid_x - xcen)**2 / (2 * xsigma**2)
                + (grid_y - ycen)**2 / (2 * ysigma**2))
    signal = np.exp(-exponent)
    signal += noise * np.random.random(size=signal.shape)
    return x, y, signal
def main(contour_levels=5, noise=0.1, compare_with_matplolib=False,
         plot_method='curve'):
    """Run a Qt app with the widget

    Shows the sum of two overlapping noisy 2D Gaussians in a :class:`Plot2D`
    with ``contour_levels`` iso-contours. Optionally draws the same image
    and contour levels with pure matplotlib for comparison.

    Parameters
    ----------
    contour_levels : int (optional)
        number of contour levels [5]
    noise : float (optional)
        noise amplitude given to ``dummy_gauss_image`` [0.1]
    compare_with_matplolib : bool (optional)
        also plot with pure matplotlib [False]
        (parameter name kept as is for backward compatibility)
    plot_method : str (optional)
        contour plot method forwarded to ``Plot2D.addContours`` ['curve']
    """
    from silx import sx
    sx.enable_gui()
    xhalfrng = 10.5
    yhalfrng = 5.5
    npts = 1024
    xcen = 0
    ycen = 0
    x = np.linspace(xcen-0.7*xhalfrng, xcen+1.3*xhalfrng, npts)
    y = np.linspace(ycen-0.7*yhalfrng, ycen+1.3*yhalfrng, npts)
    x1, y1, signal1 = dummy_gauss_image(x=x, y=y, xcen=xcen, ycen=ycen,
                                        xsigma=3, ysigma=1.1,
                                        noise=noise)
    x2, y2, signal2 = dummy_gauss_image(x=x, y=y,
                                        xcen=4.2, ycen=2.2,
                                        xsigma=3, ysigma=2.1,
                                        noise=noise)
    signal = signal1 + 0.8*signal2
    p = Plot2D(backend='matplotlib')
    p.addImage(signal, x=x, y=y, xlabel='X', ylabel='Y')
    p.addContours(contour_levels, plot_method=plot_method)
    p.show()

    if compare_with_matplolib:
        import matplotlib.pyplot as plt
        plt.ion()
        plt.close('all')
        _, ax = plt.subplots()
        imgMin, imgMax = np.min(signal), np.max(signal)
        # Use the same levels as Plot2D.addContours (contour_levels equally
        # spaced values from the minimum, maximum excluded) so that the two
        # plots can actually be compared.
        values = np.linspace(imgMin, imgMax, contour_levels, endpoint=False)
        extent = (x.min(), x.max(), y.min(), y.max())
        ax.imshow(signal, origin='lower', extent=extent, cmap='viridis')
        # origin/extent are irrelevant for contour when x and y are given
        ax.contour(x, y, signal, values, colors='gray', linewidths=1)
        ax.set_title("pure matplotlib")
        plt.show()

    input("Press enter to close window")
if __name__ == "__main__":
    # run the demo only when executed as a script, not on import
    main()