Coverage for /home/martinb/.local/share/virtualenvs/camcops/lib/python3.6/site-packages/matplotlib/streamplot.py : 11%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2Streamline plotting for 2D vector fields.
4"""
6import numpy as np
8import matplotlib
9import matplotlib.cbook as cbook
10import matplotlib.cm as cm
11import matplotlib.colors as mcolors
12import matplotlib.collections as mcollections
13import matplotlib.lines as mlines
14import matplotlib.patches as patches
# Public API of this module: a star-import exposes only `streamplot`.
__all__ = ['streamplot']
def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
               cmap=None, norm=None, arrowsize=1, arrowstyle='-|>',
               minlength=0.1, transform=None, zorder=None, start_points=None,
               maxlength=4.0, integration_direction='both'):
    """
    Draw streamlines of a vector flow.

    Parameters
    ----------
    axes : `~matplotlib.axes.Axes`
        The axes to draw into. The arrow patches and the streamline
        collection are added to it, and its view is autoscaled afterwards.
    x, y : 1D arrays
        An evenly spaced grid. 2D arrays are also accepted provided every
        row of *x* is identical and every column of *y* is identical.
    u, v : 2D arrays
        *x* and *y*-velocities. The number of rows and columns must match
        the length of *y* and *x*, respectively.
    density : float or (float, float)
        Controls the closeness of streamlines. When ``density = 1``, the domain
        is divided into a 30x30 grid. *density* linearly scales this grid.
        Each cell in the grid can have, at most, one traversing streamline.
        For different densities in each direction, use a tuple
        (density_x, density_y).
    linewidth : float or 2D array
        The width of the stream lines. With a 2D array the line width can be
        varied across the grid. The array must have the same shape as *u*
        and *v*.
    color : matplotlib color code, or 2D array
        The streamline color. If given an array, its values are converted to
        colors using *cmap* and *norm*. The array must have the same shape
        as *u* and *v*.
    cmap : `~matplotlib.colors.Colormap`
        Colormap used to plot streamlines and arrows. This is only used if
        *color* is an array.
    norm : `~matplotlib.colors.Normalize`
        Normalize object used to scale luminance data to 0, 1. If ``None``,
        stretch (min, max) to (0, 1). This is only used if *color* is an array.
    arrowsize : float
        Scaling factor for the arrow size.
    arrowstyle : str
        Arrow style specification.
        See `~matplotlib.patches.FancyArrowPatch`.
    minlength : float
        Minimum length of streamline in axes coordinates.
    transform : `~matplotlib.transforms.Transform`, optional
        Transform applied to the streamlines and arrows. Defaults to
        ``axes.transData`` (data coordinates).
    start_points : Nx2 array
        Coordinates of starting points for the streamlines in data coordinates
        (the same coordinates as the *x* and *y* arrays).
    zorder : int
        The zorder of the stream lines and arrows.
        Artists with lower zorder values are drawn first.
    maxlength : float
        Maximum length of streamline in axes coordinates.
    integration_direction : {'forward', 'backward', 'both'}
        Integrate the streamline in forward, backward or both directions.
        default is ``'both'``.

    Returns
    -------
    stream_container : StreamplotSet
        Container object with attributes

        - ``lines``: `.LineCollection` of streamlines

        - ``arrows``: `.PatchCollection` containing `.FancyArrowPatch`
          objects representing the arrows half-way along stream lines.
          (The individual arrow patches are what is added to *axes*; this
          collection itself is not added.)

        This container will probably change in the future to allow changes
        to the colormap, alpha, etc. for both lines and arrows, but these
        changes should be backward compatible.
    """
    grid = Grid(x, y)
    # Coarse occupancy mask whose resolution sets the streamline spacing.
    mask = StreamMask(density)
    dmap = DomainMap(grid, mask)

    if zorder is None:
        zorder = mlines.Line2D.zorder

    # default to data coordinates
    if transform is None:
        transform = axes.transData

    if color is None:
        color = axes._get_lines.get_next_color()

    if linewidth is None:
        linewidth = matplotlib.rcParams['lines.linewidth']

    line_kw = {}
    arrow_kw = dict(arrowstyle=arrowstyle, mutation_scale=10 * arrowsize)

    cbook._check_in_list(['both', 'forward', 'backward'],
                         integration_direction=integration_direction)

    # With 'both', the backward and the forward half are each integrated up
    # to maxlength, so halve it to keep the total within the requested limit.
    if integration_direction == 'both':
        maxlength /= 2.

    use_multicolor_lines = isinstance(color, np.ndarray)
    if use_multicolor_lines:
        if color.shape != grid.shape:
            raise ValueError("If 'color' is given, it must match the shape of "
                             "'Grid(x, y)'")
        # Per-segment color values, one array per trajectory (filled below).
        line_colors = []
        color = np.ma.masked_invalid(color)
    else:
        line_kw['color'] = color
        arrow_kw['color'] = color

    if isinstance(linewidth, np.ndarray):
        if linewidth.shape != grid.shape:
            raise ValueError("If 'linewidth' is given, it must match the "
                             "shape of 'Grid(x, y)'")
        # Filled with one width per line segment in the trajectory loop.
        line_kw['linewidth'] = []
    else:
        line_kw['linewidth'] = linewidth
        arrow_kw['linewidth'] = linewidth

    line_kw['zorder'] = zorder
    arrow_kw['zorder'] = zorder

    # Sanity checks.
    if u.shape != grid.shape or v.shape != grid.shape:
        raise ValueError("'u' and 'v' must match the shape of 'Grid(x, y)'")

    # NaN/inf velocities become masked; interpolating a masked value stops
    # the trajectory (see interpgrid).
    u = np.ma.masked_invalid(u)
    v = np.ma.masked_invalid(v)

    integrate = get_integrator(u, v, dmap, minlength, maxlength,
                               integration_direction)

    trajectories = []
    if start_points is None:
        # Seed from every mask cell that is still empty, visiting cells in
        # the spiral order produced by _gen_starting_points.
        for xm, ym in _gen_starting_points(mask.shape):
            if mask[ym, xm] == 0:
                xg, yg = dmap.mask2grid(xm, ym)
                t = integrate(xg, yg)
                if t is not None:
                    trajectories.append(t)
    else:
        sp2 = np.asanyarray(start_points, dtype=float).copy()

        # Check if start_points are outside the data boundaries
        for xs, ys in sp2:
            if not (grid.x_origin <= xs <= grid.x_origin + grid.width and
                    grid.y_origin <= ys <= grid.y_origin + grid.height):
                raise ValueError("Starting point ({}, {}) outside of data "
                                 "boundaries".format(xs, ys))

        # Convert start_points from data to array coords
        # Shift the seed points from the bottom left of the data so that
        # data2grid works properly.
        sp2[:, 0] -= grid.x_origin
        sp2[:, 1] -= grid.y_origin

        for xs, ys in sp2:
            xg, yg = dmap.data2grid(xs, ys)
            t = integrate(xg, yg)
            if t is not None:
                trajectories.append(t)

    if use_multicolor_lines:
        if norm is None:
            norm = mcolors.Normalize(color.min(), color.max())
        if cmap is None:
            cmap = cm.get_cmap(matplotlib.rcParams['image.cmap'])
        else:
            cmap = cm.get_cmap(cmap)

    streamlines = []
    arrows = []
    for t in trajectories:
        # Grid coordinates, used to interpolate linewidth/color arrays.
        tgx = np.array(t[0])
        tgy = np.array(t[1])
        # Rescale from grid-coordinates to data-coordinates.
        tx, ty = dmap.grid2data(*np.array(t))
        tx += grid.x_origin
        ty += grid.y_origin

        # One (2, 2) segment per consecutive pair of points.
        points = np.transpose([tx, ty]).reshape(-1, 1, 2)
        streamlines.extend(np.hstack([points[:-1], points[1:]]))

        # Add arrows half way along each trajectory.
        # s is the cumulative arc length; n indexes the segment in which
        # half of the total length is reached.
        s = np.cumsum(np.hypot(np.diff(tx), np.diff(ty)))
        n = np.searchsorted(s, s[-1] / 2.)
        arrow_tail = (tx[n], ty[n])
        arrow_head = (np.mean(tx[n:n + 2]), np.mean(ty[n:n + 2]))

        if isinstance(linewidth, np.ndarray):
            # Values at every point but the last: one per segment.
            line_widths = interpgrid(linewidth, tgx, tgy)[:-1]
            line_kw['linewidth'].extend(line_widths)
            arrow_kw['linewidth'] = line_widths[n]

        if use_multicolor_lines:
            color_values = interpgrid(color, tgx, tgy)[:-1]
            line_colors.append(color_values)
            arrow_kw['color'] = cmap(norm(color_values[n]))

        # arrow_kw is shared and mutated per trajectory, so each patch
        # captures this trajectory's width/color at construction time.
        p = patches.FancyArrowPatch(
            arrow_tail, arrow_head, transform=transform, **arrow_kw)
        axes.add_patch(p)
        arrows.append(p)

    lc = mcollections.LineCollection(
        streamlines, transform=transform, **line_kw)
    # Let autoscaling stop exactly at the data grid edges.
    lc.sticky_edges.x[:] = [grid.x_origin, grid.x_origin + grid.width]
    lc.sticky_edges.y[:] = [grid.y_origin, grid.y_origin + grid.height]
    if use_multicolor_lines:
        lc.set_array(np.ma.hstack(line_colors))
        lc.set_cmap(cmap)
        lc.set_norm(norm)
    axes.add_collection(lc)
    axes.autoscale_view()

    # Only groups the already-added arrow patches for the return value; this
    # collection is not added to the axes.
    ac = matplotlib.collections.PatchCollection(arrows)
    stream_container = StreamplotSet(lc, ac)
    return stream_container
class StreamplotSet:
    """Result container returned by `streamplot`.

    Attributes
    ----------
    lines
        The `.LineCollection` holding the streamline segments.
    arrows
        The `.PatchCollection` grouping the arrow patches.
    """

    def __init__(self, lines, arrows, **kwargs):
        # Additional keyword arguments are accepted for forward
        # compatibility but are not stored.
        self.lines, self.arrows = lines, arrows
241# Coordinate definitions
242# ========================
class DomainMap:
    """Convert between the coordinate systems used by `streamplot`.

    The coordinate systems are:

    * axes-coordinates: 0 to 1 across the domain.
    * data-coordinates: the user's x-y coordinates.
    * grid-coordinates: 0 to N and 0 to M for an N x M grid whose shape
      matches the input data.
    * mask-coordinates: 0 to N and 0 to M for an N x M mask whose size is
      set by the requested streamline density.

    The class also forwards trajectory bookkeeping to the `StreamMask`:
    call `start_trajectory` before tracing a streamline so the crossed mask
    cells are recorded, and `undo_trajectory` to release them again if the
    streamline is rejected (e.g. because it is too short).
    """

    def __init__(self, grid, mask):
        self.grid = grid
        self.mask = mask

        # Grid <-> mask scale factors; both index ranges run 0..n-1.
        grid_xspan = grid.nx - 1
        grid_yspan = grid.ny - 1
        self.x_grid2mask = (mask.nx - 1) / grid_xspan
        self.y_grid2mask = (mask.ny - 1) / grid_yspan
        self.x_mask2grid = 1. / self.x_grid2mask
        self.y_mask2grid = 1. / self.y_grid2mask

        # Data -> grid scaling is the reciprocal of the grid spacing.
        self.x_data2grid = 1. / grid.dx
        self.y_data2grid = 1. / grid.dy

    def grid2mask(self, xi, yi):
        """Return nearest space in mask-coords from given grid-coords."""
        # Adding 0.5 before int() truncation rounds to the nearest cell
        # for non-negative coordinates.
        xm = int(xi * self.x_grid2mask + 0.5)
        ym = int(yi * self.y_grid2mask + 0.5)
        return xm, ym

    def mask2grid(self, xm, ym):
        """Convert mask coordinates to grid coordinates."""
        xg = xm * self.x_mask2grid
        yg = ym * self.y_mask2grid
        return xg, yg

    def data2grid(self, xd, yd):
        """Convert (origin-relative) data coordinates to grid coordinates."""
        xg = xd * self.x_data2grid
        yg = yd * self.y_data2grid
        return xg, yg

    def grid2data(self, xg, yg):
        """Convert grid coordinates to (origin-relative) data coordinates."""
        xd = xg / self.x_data2grid
        yd = yg / self.y_data2grid
        return xd, yd

    def start_trajectory(self, xg, yg):
        """Begin recording a trajectory at grid point (*xg*, *yg*)."""
        self.mask._start_trajectory(*self.grid2mask(xg, yg))

    def reset_start_point(self, xg, yg):
        """Make the mask's current cell the one containing (*xg*, *yg*)."""
        self.mask._current_xy = self.grid2mask(xg, yg)

    def update_trajectory(self, xg, yg):
        """Record grid point (*xg*, *yg*) in the current trajectory.

        Raises `InvalidIndexError` if the point is outside the grid.
        """
        if self.grid.within_grid(xg, yg):
            self.mask._update_trajectory(*self.grid2mask(xg, yg))
        else:
            raise InvalidIndexError

    def undo_trajectory(self):
        """Clear the mask cells recorded for the current trajectory."""
        self.mask._undo_trajectory()
class Grid:
    """Grid of data.

    Records the size, spacing, origin and extent of an evenly spaced,
    strictly increasing grid in *x* and *y*.
    """
    def __init__(self, x, y):
        # Accept any array-like (e.g. nested lists), not only ndarrays;
        # asanyarray passes ndarray subclasses through unchanged.
        x = np.asanyarray(x)
        y = np.asanyarray(y)

        if x.ndim == 1:
            pass
        elif x.ndim == 2:
            # A 2D x (e.g. from meshgrid) must repeat the same row.
            x_row = x[0, :]
            if not np.allclose(x_row, x):
                raise ValueError("The rows of 'x' must be equal")
            x = x_row
        else:
            raise ValueError("'x' can have at maximum 2 dimensions")

        if y.ndim == 1:
            pass
        elif y.ndim == 2:
            # A 2D y must repeat the same column; comparing y_col against
            # y.T broadcasts it against every column of y.
            y_col = y[:, 0]
            if not np.allclose(y_col, y.T):
                raise ValueError("The columns of 'y' must be equal")
            y = y_col
        else:
            raise ValueError("'y' can have at maximum 2 dimensions")

        self.nx = len(x)
        self.ny = len(y)

        self.dx = x[1] - x[0]
        self.dy = y[1] - y[0]

        self.x_origin = x[0]
        self.y_origin = y[0]

        self.width = x[-1] - x[0]
        self.height = y[-1] - y[0]

        # Repeated coordinates give zero spacing (so a 1/dx conversion
        # divides by zero) and decreasing ones give a negative extent (so an
        # origin <= p <= origin + width bounds check can never succeed).
        # Reject both up front instead of producing broken streamlines.
        if not (np.diff(x) > 0).all():
            raise ValueError("'x' must be strictly increasing")
        if not (np.diff(y) > 0).all():
            raise ValueError("'y' must be strictly increasing")
        if not np.allclose(np.diff(x), self.width / (self.nx - 1)):
            raise ValueError("'x' values must be equally spaced")
        if not np.allclose(np.diff(y), self.height / (self.ny - 1)):
            raise ValueError("'y' values must be equally spaced")

    @property
    def shape(self):
        """Shape of the data arrays as ``(ny, nx)``."""
        return self.ny, self.nx

    def within_grid(self, xi, yi):
        """Return True if point is a valid index of grid."""
        # Note that xi/yi can be floats; so, for example, we can't simply check
        # `xi < self.nx` since *xi* can be `self.nx - 1 < xi < self.nx`
        return xi >= 0 and xi <= self.nx - 1 and yi >= 0 and yi <= self.ny - 1
class StreamMask:
    """Mask to keep track of discrete regions crossed by streamlines.

    The resolution of this grid determines the approximate spacing between
    trajectories. Streamlines are only allowed to pass through zeroed cells:
    When a streamline enters a cell, that cell is set to 1, and no new
    streamlines are allowed to enter.
    """

    def __init__(self, density):
        # density 1 -> 30 cells per axis; a scalar applies to both axes,
        # a pair is (density_x, density_y).
        try:
            self.nx, self.ny = (30 * np.broadcast_to(density, 2)).astype(int)
        except ValueError:
            raise ValueError("'density' must be a scalar or be of length 2")
        if self.nx < 0 or self.ny < 0:
            raise ValueError("'density' must be positive")
        self._mask = np.zeros((self.ny, self.nx))
        self.shape = self._mask.shape

        # Last mask cell (xm, ym) entered by the trajectory being recorded.
        self._current_xy = None

    def __getitem__(self, *args):
        return self._mask.__getitem__(*args)

    def _start_trajectory(self, xm, ym):
        """Start recording streamline trajectory at mask cell (*xm*, *ym*).

        Raises `InvalidIndexError` if that cell is already occupied.
        """
        self._traj = []
        # Forget the previous trajectory's last cell. Otherwise a trajectory
        # starting in exactly that cell would skip the occupancy check and
        # its start cell would never be marked (nor undone).
        self._current_xy = None
        self._update_trajectory(xm, ym)

    def _undo_trajectory(self):
        """Remove current trajectory from mask"""
        for cell in self._traj:
            self._mask[cell] = 0

    def _update_trajectory(self, xm, ym):
        """Update current trajectory position in mask.

        If the new position has already been filled, raise `InvalidIndexError`.
        Staying within the current cell is a no-op.
        """
        if self._current_xy != (xm, ym):
            if self[ym, xm] == 0:
                self._traj.append((ym, xm))
                self._mask[ym, xm] = 1
                self._current_xy = (xm, ym)
            else:
                raise InvalidIndexError
class InvalidIndexError(Exception):
    """Raised when a trajectory leaves the grid or enters an occupied mask
    cell."""


class TerminateTrajectory(Exception):
    """Raised to end the current trajectory early (zero speed or masked
    data)."""
415# Integrator definitions
416# =======================
def get_integrator(u, v, dmap, minlength, maxlength, integration_direction):
    """Build the trajectory integrator used by `streamplot`.

    Parameters
    ----------
    u, v : masked arrays
        Velocity components in data units on the grid.
    dmap : DomainMap
        Coordinate conversions plus access to the grid and mask.
    minlength, maxlength : float
        Trajectory length limits in axes coordinates.
    integration_direction : {'forward', 'backward', 'both'}
        Which time directions to integrate from each seed point.

    Returns
    -------
    integrate : callable
        ``integrate(x0, y0)`` returns ``(x_traj, y_traj)`` lists of grid
        coordinates, or None when the seed cell is already taken or the
        trajectory is not longer than *minlength*.
    """
    # Express the velocity field in grid coordinates.
    u, v = dmap.data2grid(u, v)

    # Path length ("speed") is measured in axes coordinates.
    u_ax = u / (dmap.grid.nx - 1)
    v_ax = v / (dmap.grid.ny - 1)
    speed = np.ma.sqrt(u_ax ** 2 + v_ax ** 2)

    def forward_time(xi, yi):
        # Direction of travel per unit axes-coordinate path length.
        if not dmap.grid.within_grid(xi, yi):
            raise OutOfBounds
        ds_dt = interpgrid(speed, xi, yi)
        if ds_dt == 0:
            raise TerminateTrajectory()
        dt_ds = 1. / ds_dt
        ui = interpgrid(u, xi, yi)
        vi = interpgrid(v, xi, yi)
        return ui * dt_ds, vi * dt_ds

    def backward_time(xi, yi):
        dxi, dyi = forward_time(xi, yi)
        return -dxi, -dyi

    def integrate(x0, y0):
        """Trace a trajectory through grid point (*x0*, *y0*).

        Depending on *integration_direction*, integrate backward and/or
        forward in time from the seed. Integration stops at the domain
        boundary or on entering an already occupied StreamMask cell.
        Returns None when the seed cell is occupied or the total length does
        not exceed *minlength* (the trajectory is then removed from the mask).
        """
        try:
            dmap.start_trajectory(x0, y0)
        except InvalidIndexError:
            return None

        stotal = 0.
        x_traj = []
        y_traj = []

        if integration_direction in ('both', 'backward'):
            s, xt, yt = _integrate_rk12(x0, y0, dmap, backward_time, maxlength)
            stotal += s
            # Stored seed-outward; reverse so the path runs toward the seed.
            x_traj.extend(reversed(xt))
            y_traj.extend(reversed(yt))

        if integration_direction in ('both', 'forward'):
            dmap.reset_start_point(x0, y0)
            s, xt, yt = _integrate_rk12(x0, y0, dmap, forward_time, maxlength)
            if x_traj:
                # The seed point is already the last backward point.
                xt = xt[1:]
                yt = yt[1:]
            stotal += s
            x_traj.extend(xt)
            y_traj.extend(yt)

        if stotal > minlength:
            return x_traj, y_traj
        # Too short: release the mask cells it claimed.
        dmap.undo_trajectory()
        return None

    return integrate
class OutOfBounds(IndexError):
    """Raised when an integration step lands outside the data grid."""
489def _integrate_rk12(x0, y0, dmap, f, maxlength):
490 """2nd-order Runge-Kutta algorithm with adaptive step size.
492 This method is also referred to as the improved Euler's method, or Heun's
493 method. This method is favored over higher-order methods because:
495 1. To get decent looking trajectories and to sample every mask cell
496 on the trajectory we need a small timestep, so a lower order
497 solver doesn't hurt us unless the data is *very* high resolution.
498 In fact, for cases where the user inputs
499 data smaller or of similar grid size to the mask grid, the higher
500 order corrections are negligible because of the very fast linear
501 interpolation used in `interpgrid`.
503 2. For high resolution input data (i.e. beyond the mask
504 resolution), we must reduce the timestep. Therefore, an adaptive
505 timestep is more suited to the problem as this would be very hard
506 to judge automatically otherwise.
508 This integrator is about 1.5 - 2x as fast as both the RK4 and RK45
509 solvers in most setups on my machine. I would recommend removing the
510 other two to keep things simple.
511 """
512 # This error is below that needed to match the RK4 integrator. It
513 # is set for visual reasons -- too low and corners start
514 # appearing ugly and jagged. Can be tuned.
515 maxerror = 0.003
517 # This limit is important (for all integrators) to avoid the
518 # trajectory skipping some mask cells. We could relax this
519 # condition if we use the code which is commented out below to
520 # increment the location gradually. However, due to the efficient
521 # nature of the interpolation, this doesn't boost speed by much
522 # for quite a bit of complexity.
523 maxds = min(1. / dmap.mask.nx, 1. / dmap.mask.ny, 0.1)
525 ds = maxds
526 stotal = 0
527 xi = x0
528 yi = y0
529 xf_traj = []
530 yf_traj = []
532 while True:
533 try:
534 if dmap.grid.within_grid(xi, yi):
535 xf_traj.append(xi)
536 yf_traj.append(yi)
537 else:
538 raise OutOfBounds
540 # Compute the two intermediate gradients.
541 # f should raise OutOfBounds if the locations given are
542 # outside the grid.
543 k1x, k1y = f(xi, yi)
544 k2x, k2y = f(xi + ds * k1x, yi + ds * k1y)
546 except OutOfBounds:
547 # Out of the domain during this step.
548 # Take an Euler step to the boundary to improve neatness
549 # unless the trajectory is currently empty.
550 if xf_traj:
551 ds, xf_traj, yf_traj = _euler_step(xf_traj, yf_traj,
552 dmap, f)
553 stotal += ds
554 break
555 except TerminateTrajectory:
556 break
558 dx1 = ds * k1x
559 dy1 = ds * k1y
560 dx2 = ds * 0.5 * (k1x + k2x)
561 dy2 = ds * 0.5 * (k1y + k2y)
563 nx, ny = dmap.grid.shape
564 # Error is normalized to the axes coordinates
565 error = np.hypot((dx2 - dx1) / (nx - 1), (dy2 - dy1) / (ny - 1))
567 # Only save step if within error tolerance
568 if error < maxerror:
569 xi += dx2
570 yi += dy2
571 try:
572 dmap.update_trajectory(xi, yi)
573 except InvalidIndexError:
574 break
575 if stotal + ds > maxlength:
576 break
577 stotal += ds
579 # recalculate stepsize based on step error
580 if error == 0:
581 ds = maxds
582 else:
583 ds = min(maxds, 0.85 * ds * (maxerror / error) ** 0.5)
585 return stotal, xf_traj, yf_traj
588def _euler_step(xf_traj, yf_traj, dmap, f):
589 """Simple Euler integration step that extends streamline to boundary."""
590 ny, nx = dmap.grid.shape
591 xi = xf_traj[-1]
592 yi = yf_traj[-1]
593 cx, cy = f(xi, yi)
594 if cx == 0:
595 dsx = np.inf
596 elif cx < 0:
597 dsx = xi / -cx
598 else:
599 dsx = (nx - 1 - xi) / cx
600 if cy == 0:
601 dsy = np.inf
602 elif cy < 0:
603 dsy = yi / -cy
604 else:
605 dsy = (ny - 1 - yi) / cy
606 ds = min(dsx, dsy)
607 xf_traj.append(xi + cx * ds)
608 yf_traj.append(yi + cy * ds)
609 return ds, xf_traj, yf_traj
612# Utility functions
613# ========================
def interpgrid(a, xi, yi):
    """Fast 2D, linear interpolation on an integer grid.

    *xi*, *yi* are either scalars or ndarrays of column/row positions in
    *a*. For scalar positions, a masked interpolation result raises
    `TerminateTrajectory`; array results are returned as is.
    """
    Ny, Nx = np.shape(a)
    vectorized = isinstance(xi, np.ndarray)

    if vectorized:
        x = xi.astype(int)
        y = yi.astype(int)
        # Keep the "next" neighbour index inside the array.
        xn = np.clip(x + 1, 0, Nx - 1)
        yn = np.clip(y + 1, 0, Ny - 1)
    else:
        x = int(xi)
        y = int(yi)
        # On the last column/row the neighbour is the point itself; a plain
        # conditional is cheaper than clipping for scalars.
        xn = x if x == (Nx - 1) else x + 1
        yn = y if y == (Ny - 1) else y + 1

    # Fractional position inside the cell.
    xt = xi - x
    yt = yi - y
    # Interpolate along x on both rows, then along y between them.
    row0 = a[y, x] * (1 - xt) + a[y, xn] * xt
    row1 = a[yn, x] * (1 - xt) + a[yn, xn] * xt
    ai = row0 * (1 - yt) + row1 * yt

    if not vectorized and np.ma.is_masked(ai):
        raise TerminateTrajectory

    return ai
655def _gen_starting_points(shape):
656 """Yield starting points for streamlines.
658 Trying points on the boundary first gives higher quality streamlines.
659 This algorithm starts with a point on the mask corner and spirals inward.
660 This algorithm is inefficient, but fast compared to rest of streamplot.
661 """
662 ny, nx = shape
663 xfirst = 0
664 yfirst = 1
665 xlast = nx - 1
666 ylast = ny - 1
667 x, y = 0, 0
668 direction = 'right'
669 for i in range(nx * ny):
670 yield x, y
672 if direction == 'right':
673 x += 1
674 if x >= xlast:
675 xlast -= 1
676 direction = 'up'
677 elif direction == 'up':
678 y += 1
679 if y >= ylast:
680 ylast -= 1
681 direction = 'left'
682 elif direction == 'left':
683 x -= 1
684 if x <= xfirst:
685 xfirst += 1
686 direction = 'down'
687 elif direction == 'down':
688 y -= 1
689 if y <= yfirst:
690 yfirst += 1
691 direction = 'right'