Coverage for C:\src\imod-python\imod\prepare\regrid.py: 95%
253 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-08 13:27 +0200
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-08 13:27 +0200
1"""
2Module that provides a class to do a variety of regridding operations, up to
3three dimensions.
5Before regridding, the dimensions over which regridding should occur are
6inferred, using the functions in the imod.prepare.common module. In case
7multiple dimensions are present, the data is reshaped such that a single loop
8will regrid them all.
10For example: let there be a DataArray with dimensions time, layer, y, and x. We
11wish to regrid using an area weighted mean, over x and y. This means values
12across times and layers are not aggregated together. In this case, the array is
13reshaped into a 3D array, rather than a 4D array. Time and layer are stacked
14into this first dimension together, so that a single loop suffices (see
15common._reshape and _iter_regrid).
17Functions can be incorporated into the multidimensional regridding. This is done
18by making use of numba closures, since there's an overhead to passing function
19objects directly. In this case, the function is simply compiled into the
20specific regridding method, without additional overhead.
22The regrid methods _regrid_{n}d are quite straightforward. Using the indices
23and weights that have been gathered by _weights_1d, these methods fetch the
24values from the source array (src), and pass them on to the aggregation method.
25The single aggregated value is then filled into the destination array (dst).
26"""
28from collections import namedtuple
30import dask
31import numba
32import numpy as np
33import xarray as xr
35from imod.prepare import common, interpolate
37_RegridInfo = namedtuple(
38 typename="_RegridInfo",
39 field_names=[
40 "matching_dims",
41 "regrid_dims",
42 "add_dims",
43 "dst_shape",
44 "dst_dims",
45 "dst_da_coords",
46 "src_coords_regrid",
47 "dst_coords_regrid",
48 ],
49)
@numba.njit(cache=True)
def _regrid_1d(src, dst, values, weights, method, *inds_weights):
    """
    numba compiled function to regrid in one dimension

    Parameters
    ----------
    src : np.array
        Source values (1D).
    dst : np.array
        Destination array (1D); filled in place and returned.
    values : np.array
        Scratch storage for the source values gathered per destination cell;
        must be long enough to hold the largest block.
    weights : np.array
        Scratch storage for the weights matching ``values``.
    method : numba.njit'ed function
        Aggregation function, called as ``method(values, weights)``.
    *inds_weights : (kk, blocks_ix, blocks_weights_x)
        ``kk`` holds destination indices; for every entry of ``kk`` the
        corresponding row of ``blocks_ix`` / ``blocks_weights_x`` holds the
        contributing source indices and weights. A negative source index marks
        the end of a block.

    Returns
    -------
    dst : np.array
    """
    kk, blocks_ix, blocks_weights_x = inds_weights
    # k are indices of dst array
    # block_i contains indices of src array
    # block_w contains weights of src array
    for countk, k in enumerate(kk):
        block_ix = blocks_ix[countk]
        block_wx = blocks_weights_x[countk]
        # Add the values and weights per cell in multi-dim block
        count = 0
        for ix, wx in zip(block_ix, block_wx):
            if ix < 0:
                # Negative index: padding, no more source cells in this block
                break
            values[count] = src[ix]
            weights[count] = wx
            count += 1

        # aggregate
        dst[k] = method(values[:count], weights[:count])

        # reset storage (same float literal as the 2d/3d kernels)
        values[:count] = 0.0
        weights[:count] = 0.0

    return dst
@numba.njit(cache=True)
def _regrid_2d(src, dst, values, weights, method, *inds_weights):
    """
    numba compiled function to regrid in two dimensions

    Parameters
    ----------
    src : np.array
        Source values (2D).
    dst : np.array
        Destination array (2D); filled in place and returned.
    values : np.array
        Scratch storage for the source values gathered per destination cell;
        must be long enough to hold the largest (y-block x x-block) product.
    weights : np.array
        Scratch storage for the weights matching ``values``.
    method : numba.njit'ed function
        Aggregation function, called as ``method(values, weights)``.
    *inds_weights : (jj, blocks_iy, blocks_weights_y, kk, blocks_ix, blocks_weights_x)
        Destination indices and per-index source blocks (indices, weights) for
        the first and second regrid axis respectively. A negative source index
        marks the end of a block.

    Returns
    -------
    dst : np.array
    """
    jj, blocks_iy, blocks_weights_y, kk, blocks_ix, blocks_weights_x = inds_weights

    # j, k are indices of dst array
    # block_i contains indices of src array
    # block_w contains weights of src array
    for countj, j in enumerate(jj):
        block_iy = blocks_iy[countj]
        block_wy = blocks_weights_y[countj]
        for countk, k in enumerate(kk):
            block_ix = blocks_ix[countk]
            block_wx = blocks_weights_x[countk]
            # Add the values and weights per cell in multi-dim block
            count = 0
            for iy, wy in zip(block_iy, block_wy):
                if iy < 0:
                    break
                for ix, wx in zip(block_ix, block_wx):
                    if ix < 0:
                        break
                    values[count] = src[iy, ix]
                    # Combined weight is the product of the per-axis weights
                    weights[count] = wy * wx
                    count += 1

            # aggregate
            dst[j, k] = method(values[:count], weights[:count])

            # reset storage
            values[:count] = 0.0
            weights[:count] = 0.0

    return dst
@numba.njit(cache=True)
def _regrid_3d(src, dst, values, weights, method, *inds_weights):
    """
    numba compiled function to regrid in three dimensions

    Parameters
    ----------
    src : np.array
        Source values (3D).
    dst : np.array
        Destination array (3D); filled in place and returned.
    values : np.array
        Scratch storage for the source values gathered per destination cell;
        must be long enough to hold the largest product of the three blocks.
    weights : np.array
        Scratch storage for the weights matching ``values``.
    method : numba.njit'ed function
        Aggregation function, called as ``method(values, weights)``.
    *inds_weights : (ii, blocks_iz, blocks_weights_z, jj, blocks_iy,
                     blocks_weights_y, kk, blocks_ix, blocks_weights_x)
        Destination indices and per-index source blocks (indices, weights) for
        the first, second and third regrid axis respectively. A negative
        source index marks the end of a block.

    Returns
    -------
    dst : np.array
    """
    (
        ii,
        blocks_iz,
        blocks_weights_z,
        jj,
        blocks_iy,
        blocks_weights_y,
        kk,
        blocks_ix,
        blocks_weights_x,
    ) = inds_weights

    # i, j, k are indices of dst array
    # block_i contains indices of src array
    # block_w contains weights of src array
    for counti, i in enumerate(ii):
        block_iz = blocks_iz[counti]
        block_wz = blocks_weights_z[counti]
        for countj, j in enumerate(jj):
            block_iy = blocks_iy[countj]
            block_wy = blocks_weights_y[countj]
            for countk, k in enumerate(kk):
                block_ix = blocks_ix[countk]
                block_wx = blocks_weights_x[countk]
                # Add the values and weights per cell in multi-dim block
                count = 0
                for iz, wz in zip(block_iz, block_wz):
                    if iz < 0:
                        break
                    for iy, wy in zip(block_iy, block_wy):
                        if iy < 0:
                            break
                        for ix, wx in zip(block_ix, block_wx):
                            if ix < 0:
                                break
                            values[count] = src[iz, iy, ix]
                            # Combined weight is the product of per-axis weights
                            weights[count] = wz * wy * wx
                            count += 1

                # aggregate
                dst[i, j, k] = method(values[:count], weights[:count])

                # reset storage
                values[:count] = 0.0
                weights[:count] = 0.0

    return dst
@numba.njit
def _iter_regrid(iter_src, iter_dst, alloc_len, regrid_function, *inds_weights):
    """
    Apply ``regrid_function`` to every slice along the first axis.

    ``iter_src`` and ``iter_dst`` are stacked arrays whose leading axis
    enumerates independent slices. Two scratch buffers of length
    ``alloc_len`` are allocated once and shared by all calls of
    ``regrid_function``. Returns ``iter_dst``.
    """
    # Scratch buffers, allocated once and reused for every slice
    scratch_values = np.zeros(alloc_len)
    scratch_weights = np.zeros(alloc_len)
    for idx in range(iter_src.shape[0]):
        src_slice = iter_src[idx, ...]
        dst_slice = iter_dst[idx, ...]
        iter_dst[idx, ...] = regrid_function(
            src_slice, dst_slice, scratch_values, scratch_weights, *inds_weights
        )
    return iter_dst
def _jit_regrid(jit_method, ndim_regrid):
    """
    Return the ``_regrid_{n}d`` kernel for ``ndim_regrid`` with ``jit_method``
    baked in as its aggregation function.

    Wrapping the kernel in a closure avoids the overhead numba incurs when a
    function is passed as an argument:
    https://numba.pydata.org/numba-doc/dev/user/faq.html#can-i-pass-a-function-as-an-argument-to-a-jitted-function

    Raises
    ------
    NotImplementedError
        If ``ndim_regrid`` is not 1, 2 or 3.
    """

    @numba.njit
    def jit_regrid_1d(src, dst, values, weights, *inds_weights):
        return _regrid_1d(src, dst, values, weights, jit_method, *inds_weights)

    @numba.njit
    def jit_regrid_2d(src, dst, values, weights, *inds_weights):
        return _regrid_2d(src, dst, values, weights, jit_method, *inds_weights)

    @numba.njit
    def jit_regrid_3d(src, dst, values, weights, *inds_weights):
        return _regrid_3d(src, dst, values, weights, jit_method, *inds_weights)

    kernels = {1: jit_regrid_1d, 2: jit_regrid_2d, 3: jit_regrid_3d}
    kernel = kernels.get(ndim_regrid)
    if kernel is None:
        raise NotImplementedError("cannot regrid over more than three dimensions")
    return kernel
def _make_regrid(method, ndim_regrid):
    """
    Build the compiled function that regrids every slice of a stacked array.

    ``method`` is jit-compiled (with caching) and baked into the
    ``ndim_regrid``-dimensional kernel through closures, which avoids numba's
    overhead for function arguments:
    https://numba.pydata.org/numba-doc/dev/user/faq.html#can-i-pass-a-function-as-an-argument-to-a-jitted-function

    Returns a numba function called as
    ``iter_regrid(iter_src, iter_dst, alloc_len, *inds_weights)``.
    """
    # Compile the aggregation method and wrap it in the n-d kernel
    kernel = _jit_regrid(numba.njit(method, cache=True), ndim_regrid)

    # Compile the slice-iterating driver around that specific kernel
    @numba.njit
    def iter_regrid(iter_src, iter_dst, alloc_len, *inds_weights):
        return _iter_regrid(iter_src, iter_dst, alloc_len, kernel, *inds_weights)

    return iter_regrid
def _nd_regrid(src, dst, src_coords, dst_coords, iter_regrid, use_relative_weights):
    """
    Regrids an ndarray up to maximum 3 dimensions.

    The dimensionality of regridding is determined by the length of
    ``src_coords`` (which must equal ``len(dst_coords)``), and has to match the
    provided ``iter_regrid`` function. Callers are expected to have put the
    dimensions to regrid last (``Regridder._regrid`` transposes them there).

    Parameters
    ----------
    src : np.array
    dst : np.array
        Must have the same number of dimensions as ``src``.
    src_coords : tuple of np.array
    dst_coords : tuple of np.array
    iter_regrid : function, numba compiled
        As returned by ``_make_regrid``.
    use_relative_weights : bool
        Passed on to ``common._weights_1d``.

    Returns
    -------
    np.array
        The regridded values, reshaped to ``dst.shape``.

    Raises
    ------
    ValueError
        If ``src`` and ``dst`` differ in number of dimensions, or
        ``src_coords`` and ``dst_coords`` differ in length.
    """
    if src.ndim != dst.ndim:
        raise ValueError("shape mismatch between src and dst")
    if len(src_coords) != len(dst_coords):
        raise ValueError("coords mismatch between src and dst")
    ndim_regrid = len(src_coords)

    # Determine weights for every regrid dimension, and alloc_len,
    # the maximum number of src cells that may end up in a single dst cell
    inds_weights = []
    alloc_len = 1
    for src_x, dst_x in zip(src_coords, dst_coords):
        size, i_w = common._weights_1d(src_x, dst_x, use_relative_weights)
        inds_weights.extend(i_w)
        alloc_len *= size

    iter_src, iter_dst = common._reshape(src, dst, ndim_regrid)
    iter_dst = iter_regrid(iter_src, iter_dst, alloc_len, *inds_weights)

    return iter_dst.reshape(dst.shape)
class Regridder:
    """
    Object to repeatedly regrid similar objects. Compiles once on first call,
    can then be repeatedly called without JIT compilation overhead.

    Attributes
    ----------
    method : str, function
        The method to use for regridding. Default available methods are:
        ``{"nearest", "multilinear", "mean", "harmonic_mean", "geometric_mean",
        "sum", "minimum", "maximum", "mode", "median", "conductance"}``
    ndim_regrid : int, optional
        The number of dimensions over which to regrid. If not provided,
        ``ndim_regrid`` will be inferred. It serves to prevent regridding over an
        unexpected number of dimensions; say you want to regrid over only two
        dimensions. Due to an input error in the coordinates of ``like``, three
        dimensions may be inferred in the first ``.regrid`` call. An error will
        be raised if ndim_regrid does not match the number of inferred
        dimensions. Default value is None.
    use_relative_weights : bool, optional
        Whether to use relative weights in the regridding method or not.
        Relative weights are defined as: cell_overlap / source_cellsize, for
        every axis.

        This argument should only be used if you are providing your own
        ``method`` as a function, where the function requires relative, rather
        than absolute weights (the provided ``conductance`` method requires
        relative weights, for example). Default value is False.
    extra_overlap : int, optional
        In case of chunked regridding, how many cells of additional overlap is
        necessary. Linear interpolation requires this for example, as it reaches
        beyond cell boundaries to compute values. Default value is 0.

    Examples
    --------
    Initialize the Regridder object:

    >>> mean_regridder = imod.prepare.Regridder(method="mean")

    Then call the ``regrid`` method to regrid.

    >>> result = mean_regridder.regrid(source, like)

    The regridder can be re-used if the number of regridding dimensions
    match, saving some time by not (re)compiling the regridding method.

    >>> second_result = mean_regridder.regrid(second_source, like)

    A one-liner is possible for single use:

    >>> result = imod.prepare.Regridder(method="mean").regrid(source, like)

    It's possible to provide your own methods to the ``Regridder``, provided that
    numba can compile them. They need to take the arguments ``values`` and
    ``weights``. Make sure they deal with ``nan`` values gracefully!

    >>> def p30(values, weights):
    ...     return np.nanpercentile(values, 30)

    >>> p30_regridder = imod.prepare.Regridder(method=p30)
    >>> p30_result = p30_regridder.regrid(source, like)

    The Numba developers maintain a list of supported Numpy features here:
    https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html

    In general, however, the provided methods should be adequate for your
    regridding needs.
    """

    def __init__(
        self, method, ndim_regrid=None, use_relative_weights=False, extra_overlap=0
    ):
        _method = common._get_method(method, common.METHODS)
        self.method = _method
        self.ndim_regrid = ndim_regrid
        # The first regrid call runs eagerly so numba compiles the functions;
        # later (chunked) calls can then be delayed.
        self._first_call = True
        # The provided conductance method requires relative weights.
        if _method == common.METHODS["conductance"]:
            use_relative_weights = True
        self.use_relative_weights = use_relative_weights
        # Multilinear interpolation reaches beyond cell boundaries, so chunked
        # regridding needs one cell of extra overlap.
        if _method == common.METHODS["multilinear"]:
            extra_overlap = 1
        self.extra_overlap = extra_overlap

    def _make_regrid(self):
        """Set ``self._nd_regrid`` for the current method and ``ndim_regrid``."""
        iter_regrid = _make_regrid(self.method, self.ndim_regrid)
        iter_interp = interpolate._make_interp(self.ndim_regrid)

        def nd_regrid(src, dst, src_coords_regrid, dst_coords_regrid):
            return _nd_regrid(
                src,
                dst,
                src_coords_regrid,
                dst_coords_regrid,
                iter_regrid,
                self.use_relative_weights,
            )

        def nd_interp(src, dst, src_coords_regrid, dst_coords_regrid):
            return interpolate._nd_interp(
                src, dst, src_coords_regrid, dst_coords_regrid, iter_interp
            )

        if self.method == "nearest":
            # Nearest is handled by xarray in ``regrid``; nothing to set.
            pass
        elif self.method == "multilinear":
            self._nd_regrid = nd_interp
        else:
            self._nd_regrid = nd_regrid

    def _check_ndim_regrid(self, regrid_dims):
        """Raise ValueError if ``regrid_dims`` disagrees with ``ndim_regrid``."""
        if len(regrid_dims) != self.ndim_regrid:
            raise ValueError(
                "Number of dimensions to regrid does not match: "
                f"Regridder.ndim_regrid = {self.ndim_regrid}"
            )

    def _prepare(self, regrid_dims):
        """
        Fix (or check) ``ndim_regrid`` and build the regridding function.

        Raises ValueError on an ndim mismatch, or when the conductance method
        is asked to regrid more than two dimensions.
        """
        # Create tailor made regridding function: take method and ndims into
        # account and call it
        if self.ndim_regrid is None:
            self.ndim_regrid = len(regrid_dims)
        else:
            self._check_ndim_regrid(regrid_dims)

        if self.method == common.METHODS["conductance"] and len(regrid_dims) > 2:
            raise ValueError(
                "The conductance method should not be applied to "
                "regridding more than two dimensions"
            )
        # Create the method.
        self._make_regrid()

    @staticmethod
    def _regrid_info(src, like):
        """Collect dims, destination shape and coordinates into a _RegridInfo."""
        # Find coordinates that already match, and those that have to be
        # regridded, and those that exist in source but not in like (left
        # untouched)
        matching_dims, regrid_dims, add_dims = common._match_dims(src, like)

        # Order dimensions in the right way:
        # dimensions that are regridded end up at the end for efficient iteration
        dst_dims = (*add_dims, *matching_dims, *regrid_dims)
        dims_from_src = (*add_dims, *matching_dims)
        dims_from_like = tuple(regrid_dims)

        # Gather destination coordinates
        dst_da_coords, dst_shape = common._dst_coords(
            src, like, dims_from_src, dims_from_like
        )

        # Lazy placeholder, only used to extract destination coordinates
        dst_tmp = xr.DataArray(
            data=dask.array.empty(dst_shape), coords=dst_da_coords, dims=dst_dims
        )

        # TODO: check that axes are aligned
        src_coords_regrid = [common._coord(src, dim) for dim in regrid_dims]
        dst_coords_regrid = [common._coord(dst_tmp, dim) for dim in regrid_dims]

        return _RegridInfo(
            matching_dims=matching_dims,
            regrid_dims=regrid_dims,
            add_dims=add_dims,
            dst_shape=dst_shape,
            dst_da_coords=dst_da_coords,
            dst_dims=dst_dims,
            src_coords_regrid=src_coords_regrid,
            dst_coords_regrid=dst_coords_regrid,
        )

    def _regrid(self, src, fill_value, info):
        """Regrid an in-memory DataArray ``src``; returns a numpy array."""
        # Allocate dst
        dst = np.full(info.dst_shape, fill_value)
        # No overlap whatsoever, early exit
        if any(size == 0 for size in src.shape):
            return dst

        # Transpose src so that dims to regrid are last
        src = src.transpose(*info.dst_dims)

        # Exit early if nothing is to be done
        if len(info.regrid_dims) == 0:
            return src.values.copy()
        return self._nd_regrid(
            src.values, dst, info.src_coords_regrid, info.dst_coords_regrid
        )

    def _delayed_regrid(self, src, like, fill_value, info):
        """
        Deal with chunks in dimensions that will NOT be regridded.

        ``src`` is split along its chunks in ``info.add_dims``; every piece is
        regridded by ``_chunked_regrid`` and the pieces are stitched back
        together with ``dask.array.block``.
        """
        if len(info.add_dims) == 0:
            return self._chunked_regrid(src, like, fill_value)

        src_dim_slices = []
        shape_chunks = []
        for dim, chunk_sizes in zip(src.dims, src.chunks):
            if dim in info.add_dims:
                end = np.cumsum(chunk_sizes)
                start = end - chunk_sizes
                src_dim_slices.append([slice(s, e) for s, e in zip(start, end)])
                shape_chunks.append(len(chunk_sizes))

        # One row per combination of chunk slices over the add_dims
        src_expanded_slices = np.stack(
            [a.ravel() for a in np.meshgrid(*src_dim_slices, indexing="ij")], axis=-1
        )
        src_das = common._sel_chunks(src, info.add_dims, src_expanded_slices)
        np_collection = np.full(len(src_das), None)

        for i, src_da in enumerate(src_das):
            np_collection[i] = self._chunked_regrid(src_da, like, fill_value)

        # The result dims are (*add_dims, *matching_dims, *regrid_dims), and
        # each piece spans the full matching and regrid extent. The nesting
        # passed to dask.array.block must therefore have one level per result
        # dim, including a 1 for every matching dim. If it is shallower, block
        # concatenates along the trailing axes and the add_dims chunks end up
        # on the wrong axis.
        n_whole_dims = len(info.matching_dims) + len(info.regrid_dims)
        shape_chunks = shape_chunks + [1] * n_whole_dims
        reshaped_collection = np.reshape(np_collection, shape_chunks).tolist()
        return dask.array.block(reshaped_collection)

    def _chunked_regrid(self, src, like, fill_value):
        """
        Deal with chunks in dimensions that will be regridded.
        """
        like_expanded_slices, shape_chunks = common._define_slices(src, like)
        like_das = common._sel_chunks(like, like.dims, like_expanded_slices)
        np_collection = np.full(len(like_das), None)

        # Regridder should compute first chunk once
        # so numba has compiled the necessary functions for subsequent chunks
        for i, dst_da in enumerate(like_das):
            chunk_src = common._slice_src(src, dst_da, self.extra_overlap)
            info = self._regrid_info(chunk_src, dst_da)

            if any(size == 0 for size in chunk_src.shape):
                # zero overlap for the chunk, zero size chunk
                # N.B. Make sure to include chunks=-1, defaults to chunks="auto", which
                # automatically results in unnecessary, error prone chunks.
                # TODO: Not covered by tests -- but also rather hard to test.
                dask_array = dask.array.full(
                    shape=info.dst_shape,
                    fill_value=fill_value,
                    dtype=src.dtype,
                    chunks=-1,
                )
            elif self._first_call:
                # NOT delayed, trigger compilation
                a = self._regrid(chunk_src, fill_value, info)
                dask_array = dask.array.from_array(a, chunks=-1)
                self._first_call = False
            else:
                # Allocation occurs inside
                a = dask.delayed(self._regrid, pure=True)(chunk_src, fill_value, info)
                dask_array = dask.array.from_delayed(
                    a, shape=info.dst_shape, dtype=src.dtype
                )

            np_collection[i] = dask_array

        # Determine the shape of the chunks, and reshape so dask.block does the right thing
        reshaped_collection = np.reshape(np_collection, shape_chunks).tolist()
        return dask.array.block(reshaped_collection)

    def regrid(self, source, like, fill_value=np.nan):
        """
        Regrid ``source`` along dimensions that ``source`` and ``like`` share.
        These dimensions will be inferred the first time ``.regrid`` is called
        for the Regridder object.

        Following xarray conventions, nodata is assumed to be ``np.nan``.

        Parameters
        ----------
        source : xr.DataArray of floats
        like : xr.DataArray of floats
            The like array presents what the coordinates should look like.
        fill_value : float
            The fill_value. Defaults to np.nan

        Returns
        -------
        result : xr.DataArray
            Regridded result.

        Raises
        ------
        TypeError
            If ``source`` or ``like`` is not an ``xr.DataArray``.
        ValueError
            If the number of dimensions to regrid does not match
            ``ndim_regrid``, or the conductance method is asked to regrid
            more than two dimensions.
        """
        if not isinstance(source, xr.DataArray):
            raise TypeError("source must be a DataArray")
        if not isinstance(like, xr.DataArray):
            raise TypeError("like must be a DataArray")

        # Don't mutate source; src stands for source, dst for destination
        src = source.copy(deep=False)
        like = like.copy(deep=False)
        _, regrid_dims, _ = common._match_dims(src, like)
        # Exit early if nothing is to be done
        if len(regrid_dims) == 0:
            return source.copy(deep=True)

        # Collect dimensions to flip to make everything ascending
        src, _ = common._increasing_dims(src, regrid_dims)
        like, flip_dst = common._increasing_dims(like, regrid_dims)

        info = self._regrid_info(source, like)
        # Use xarray for nearest
        # TODO: replace by more efficient, specialized method
        # NOTE(review): this early return skips the flip back over ``flip_dst``
        # at the end of this method, so flipped dims come back increasing
        # rather than oriented as the original ``like``. Confirm if intended.
        if self.method == "nearest":
            dst = source.reindex_like(like, method="nearest")
            dst = dst.assign_coords(info.dst_da_coords)
            return dst

        # Prepare for regridding; quick checks
        if self._first_call:
            self._prepare(info.regrid_dims)
        self._check_ndim_regrid(info.regrid_dims)

        if src.chunks is None:
            src = common._slice_src(src, like, self.extra_overlap)
            # Recollect info with sliced part of src
            info = self._regrid_info(src, like)
            data = self._regrid(src, fill_value, info)
            self._first_call = False
        else:
            # Ensure all dimensions have a dx coordinate, so that if the chunks
            # results in chunks which are size 1 along a dimension, the cellsize
            # can still be determined.
            src = common._set_cellsizes(src, info.regrid_dims)
            like = common._set_cellsizes(like, info.regrid_dims)
            data = self._delayed_regrid(src, like, fill_value, info)

        dst = xr.DataArray(data=data, coords=info.dst_da_coords, dims=info.dst_dims)
        # Replace equidistant cellsize arrays by scalar values
        dst = common._set_scalar_cellsizes(dst)

        # Flip dimensions to return as like
        for dim in flip_dst:
            dst = dst.sel({dim: slice(None, None, -1)})

        # Transpose to original dimension coordinates
        return dst.transpose(*source.dims)