# imod/prepare/common.py
1"""
2Common methods used for interpolation, voxelization.
4Includes methods for dealing with different coordinates and dimensions of the
5xarray.DataArrays, as well as aggregation methods operating on weights and
6values.
7"""
9from typing import Any
11import cftime
12import numba
13import numpy as np
15import imod


@numba.njit
def _starts(src_x, dst_x):
    """
    Yield pairs (i, j): for every destination cell i, the index j of the
    first source cell that may overlap with it.

    Parameters
    ----------
    src_x : np.array
        vertex coordinates of source
    dst_x : np.array
        vertex coordinates of destination
    """
    i = 0
    j = 0
    while i < dst_x.size - 1:
        x = dst_x[i]
        while j < src_x.size:
            if src_x[j] > x:
                out = max(j - 1, 0)
                yield (i, out)
                break
            else:
                j += 1
        i += 1
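

# Illustrative sketch of what _starts yields (example values assumed, not
# taken from the package): for each destination cell i, it gives the index of
# the first source cell that may overlap it.
#
#   >>> src_x = np.array([0.0, 1.0, 2.0, 3.0, 4.0])  # four source cells
#   >>> dst_x = np.array([0.0, 2.0, 4.0])  # two destination cells
#   >>> list(_starts(src_x, dst_x))
#   [(0, 0), (1, 2)]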


def _weights_1d(src_x, dst_x, use_relative_weights=False):
    """
    Calculate regridding weights and indices for a single dimension.

    Parameters
    ----------
    src_x : np.array
        vertex coordinates of source
    dst_x : np.array
        vertex coordinates of destination
    use_relative_weights : bool, optional
        whether to divide every overlap by the size of its source cell

    Returns
    -------
    max_len : int
        maximum number of source cells contributing to a single destination
        cell for this dimension
    (dst_inds, src_inds, weights) : tuple of np.array
        dst_inds: destination cell index, one per destination cell with overlap
        src_inds: source cell indices, per destination index; padded with -1
        weights: weight of each source cell, per destination index; padded
        with 0.0
    """
    max_len = 0
    dst_inds = []
    src_inds = []
    weights = []
    rel_weights = []

    # i is the index of dst
    # j is the index of src
    for i, j in _starts(src_x, dst_x):
        dst_x0 = dst_x[i]
        dst_x1 = dst_x[i + 1]

        _inds = []
        _weights = []
        _rel_weights = []
        has_value = False
        while j < src_x.size - 1:
            src_x0 = src_x[j]
            src_x1 = src_x[j + 1]
            overlap = _overlap((dst_x0, dst_x1), (src_x0, src_x1))
            # No longer any overlap: continue to the next dst cell
            if overlap == 0:
                break
            else:
                has_value = True
                _inds.append(j)
                _weights.append(overlap)
                relative_overlap = overlap / (src_x1 - src_x0)
                _rel_weights.append(relative_overlap)
                j += 1
        if has_value:
            dst_inds.append(i)
            src_inds.append(_inds)
            weights.append(_weights)
            rel_weights.append(_rel_weights)
            # Save the max number of source cells, so we know how much to
            # pre-allocate later on
            inds_len = len(_inds)
            if inds_len > max_len:
                max_len = inds_len

    # Convert all output to numpy arrays: numba does NOT like arrays or lists
    # in tuples, and compilation time goes through the roof otherwise.
    nrow = len(dst_inds)
    ncol = max_len
    np_dst_inds = np.array(dst_inds)

    np_src_inds = np.full((nrow, ncol), -1)
    for i in range(nrow):
        for j, ind in enumerate(src_inds[i]):
            np_src_inds[i, j] = ind

    np_weights = np.full((nrow, ncol), 0.0)
    if use_relative_weights:
        weights = rel_weights
    for i in range(nrow):
        for j, w in enumerate(weights[i]):
            np_weights[i, j] = w

    return max_len, (np_dst_inds, np_src_inds, np_weights)
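

# A minimal sketch of the output, with assumed example vertices: each
# destination cell covers exactly two source cells, so max_len is 2 and all
# absolute weights (overlap lengths) are 1.0.
#
#   >>> src_x = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
#   >>> dst_x = np.array([0.0, 2.0, 4.0])
#   >>> max_len, (dst_inds, src_inds, weights) = _weights_1d(src_x, dst_x)
#   >>> max_len
#   2
#   >>> dst_inds
#   array([0, 1])
#   >>> src_inds
#   array([[0, 1],
#          [2, 3]])
#   >>> weights
#   array([[1., 1.],
#          [1., 1.]])
#
# With use_relative_weights=True, each overlap is divided by its source cell
# size, which happens to give identical weights here (source cells have
# size 1.0).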


def _reshape(src, dst, ndim_regrid):
    """
    If ndim > ndim_regrid, the non-regridding dimensions are combined into a
    single dimension, so we can use a single loop, irrespective of the total
    number of dimensions. (The alternative is pre-writing separate for-loops
    for every number of dimensions we intend to support.)

    If ndim == ndim_regrid, all dimensions will be used in regridding. In
    that case no looping over other dimensions is required, and we add a
    dummy dimension here so there is something to iterate over.
    """
    src_shape = src.shape
    dst_shape = dst.shape
    ndim = len(src_shape)

    if ndim == ndim_regrid:
        n_iter = 1
    else:
        n_iter = int(np.prod(src_shape[:-ndim_regrid]))

    src_itershape = (n_iter, *src_shape[-ndim_regrid:])
    dst_itershape = (n_iter, *dst_shape[-ndim_regrid:])

    iter_src = np.reshape(src, src_itershape)
    iter_dst = np.reshape(dst, dst_itershape)

    return iter_src, iter_dst
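

# Shape bookkeeping sketch (shapes assumed for illustration): regridding the
# last two dimensions of a four-dimensional array collapses the two leading
# dimensions into a single iteration axis of size 2 * 3 = 6.
#
#   >>> src = np.zeros((2, 3, 4, 4))
#   >>> dst = np.zeros((2, 3, 2, 2))
#   >>> iter_src, iter_dst = _reshape(src, dst, ndim_regrid=2)
#   >>> iter_src.shape, iter_dst.shape
#   ((6, 4, 4), (6, 2, 2))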


def _is_subset(a1, a2):
    if np.isin(a2, a1).all():
        # All elements of a2 are present in a1; now check whether they form a
        # contiguous block by fetching the indices of a1 that occur in a2
        idx = np.arange(a1.size)[np.isin(a1, a2)]
        if idx.size > 1:
            increment = np.diff(idx)
            # If the maximum increment is 1, the indices are consecutive:
            # a2 is a contiguous subset
            if increment.max() == 1:
                return True
    return False
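

# "Subset" here means a contiguous run: all elements of a2 must occur in a1,
# at consecutive positions. A sketch with assumed values:
#
#   >>> a1 = np.array([0.0, 1.0, 2.0, 3.0])
#   >>> _is_subset(a1, np.array([1.0, 2.0]))
#   True
#   >>> _is_subset(a1, np.array([1.0, 3.0]))  # present, but not contiguous
#   False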


def _match_dims(src, like):
    """
    Parameters
    ----------
    src : xr.DataArray
        The source DataArray to be regridded
    like : xr.DataArray
        Example DataArray that shows what the resampled result should look
        like in terms of coordinates. ``src`` is regridded along dimensions
        of ``like`` that have the same name, but different values.

    Returns
    -------
    matching_dims, regrid_dims, add_dims : tuple of lists
        matching_dims: dimensions along which the coordinates match exactly
        regrid_dims: dimensions along which src will be regridded
        add_dims: dimensions that are not present in like
    """
    # TODO: deal with different extent?
    # Do another check if not identical
    # Check if subset or superset?
    matching_dims = []
    regrid_dims = []
    add_dims = []
    for dim in src.dims:
        if dim not in like.dims:
            add_dims.append(dim)
        elif src[dim].size == 0:  # zero overlap
            regrid_dims.append(dim)
        else:
            try:
                a1 = _coord(src, dim)
                a2 = _coord(like, dim)
                if np.array_equal(a1, a2) or _is_subset(a1, a2):
                    matching_dims.append(dim)
                else:
                    regrid_dims.append(dim)
            except TypeError:
                first_type = type(like[dim].values[0])
                if issubclass(first_type, (cftime.datetime, np.datetime64)):
                    raise RuntimeError(
                        "cannot regrid over datetime dimensions. "
                        "Use xarray.Dataset.resample() instead"
                    )

    ndim_regrid = len(regrid_dims)
    # Check the number of dimensions to regrid
    if ndim_regrid > 3:
        raise NotImplementedError("cannot regrid over more than three dimensions")

    return matching_dims, regrid_dims, add_dims
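

# A sketch of the classification with assumed DataArrays: "layer" occurs only
# in src, "y" matches exactly, and "x" has different cell boundaries, so only
# "x" needs regridding.
#
#   >>> import xarray as xr
#   >>> src = xr.DataArray(
#   ...     np.zeros((2, 2, 4)),
#   ...     coords={"layer": [1, 2], "y": [1.5, 0.5], "x": [0.5, 1.5, 2.5, 3.5]},
#   ...     dims=["layer", "y", "x"],
#   ... )
#   >>> like = xr.DataArray(
#   ...     np.zeros((2, 2)),
#   ...     coords={"y": [1.5, 0.5], "x": [1.0, 3.0]},
#   ...     dims=["y", "x"],
#   ... )
#   >>> _match_dims(src, like)
#   (['y'], ['x'], ['layer'])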


def _increasing_dims(da, dims):
    flip_dims = []
    for dim in dims:
        if not da.indexes[dim].is_monotonic_increasing:
            flip_dims.append(dim)
            da = da.isel({dim: slice(None, None, -1)})
    return da, flip_dims


def _selection_indices(src_x, xmin, xmax, extra_overlap):
    """Left-inclusive"""
    # Extra overlap is needed, for example with (multi)linear interpolation.
    # We simply enlarge the slice at the start and at the end.
    i0 = max(0, np.searchsorted(src_x, xmin, side="right") - 1 - extra_overlap)
    i1 = np.searchsorted(src_x, xmax, side="left") + extra_overlap
    return i0, i1
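

# Index sketch with assumed vertices: selecting the source cells that cover
# [1.5, 2.5] out of vertices [0, 1, 2, 3, 4] yields cells 1 and 2; with
# extra_overlap=1 the slice widens by one cell on both sides.
#
#   >>> src_x = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
#   >>> _selection_indices(src_x, 1.5, 2.5, extra_overlap=0)
#   (1, 3)
#   >>> _selection_indices(src_x, 1.5, 2.5, extra_overlap=1)
#   (0, 4)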


def _slice_src(src, like, extra_overlap):
    """
    Make sure src matches dst in the dims that do not have to be regridded
    """
    matching_dims, regrid_dims, _ = _match_dims(src, like)
    dims = matching_dims + regrid_dims

    slices = {}
    for dim in dims:
        # Generate vertices
        src_x = _coord(src, dim)
        _, xmin, xmax = imod.util.spatial.coord_reference(like[dim])
        i0, i1 = _selection_indices(src_x, xmin, xmax, extra_overlap)
        slices[dim] = slice(i0, i1)
    return src.isel(slices)


def _dst_coords(src, like, dims_from_src, dims_from_like):
    """
    Gather destination coordinates
    """
    dst_da_coords = {}
    dst_shape = []
    # TODO: do some more checking, more robust handling
    like_coords = dict(like.coords)
    for dim in dims_from_src:
        like_coords.pop(dim, None)
        dst_da_coords[dim] = src[dim].values
        dst_shape.append(src[dim].size)
    for dim in dims_from_like:
        like_coords.pop(dim, None)
        dst_da_coords[dim] = like[dim].values
        dst_shape.append(like[dim].size)

    dst_da_coords.update(like_coords)
    return dst_da_coords, dst_shape


def _check_monotonic(dxs, dim):
    # Use xor to check that the cellsizes are either all positive or all
    # negative
    if not ((dxs > 0.0).all() ^ (dxs < 0.0).all()):
        raise ValueError(f"{dim} is not monotonically increasing or decreasing")


def _set_cellsizes(da, dims):
    for dim in dims:
        dx_string = f"d{dim}"
        if dx_string not in da.coords:
            dx, _, _ = imod.util.spatial.coord_reference(da.coords[dim])
            if isinstance(dx, (int, float)):
                dx = np.full(da.coords[dim].size, dx)
            da = da.assign_coords({dx_string: (dim, dx)})
    return da


def _set_scalar_cellsizes(da):
    for dim in da.dims:
        dx_string = f"d{dim}"
        if dx_string in da.coords:
            dx = da.coords[dx_string]
            # Reduce the cellsize coordinate to a scalar if all cellsizes are
            # (nearly) identical
            if dx.ndim == 0:  # catch the case where dx already is a scalar
                dx_scalar = dx.values[()]
            else:
                dx_scalar = dx.values[0]
            if np.allclose(dx, dx_scalar):
                da = da.assign_coords({dx_string: dx_scalar})
    return da


def _coord(da, dim):
    """
    Transform N xarray midpoints into N + 1 vertex edges
    """
    delta_dim = "d" + dim  # e.g. dx, dy, dz, etc.

    # If the array is empty, return empty vertices
    if da[dim].size == 0:
        return np.array(())

    if delta_dim in da.coords:  # equidistant or non-equidistant
        dx = da[delta_dim].values
        if dx.shape == () or dx.shape == (1,):  # scalar -> equidistant
            dxs = np.full(da[dim].size, dx)
        else:  # array -> non-equidistant
            dxs = dx
        _check_monotonic(dxs, dim)

    else:  # not defined -> equidistant
        if da[dim].size == 1:
            raise ValueError(
                f"DataArray has size 1 along {dim}, so cellsize must be provided"
                " as a coordinate."
            )
        dxs = np.diff(da[dim].values)
        dx = dxs[0]
        atolx = abs(1.0e-4 * dx)
        if not np.allclose(dxs, dx, atol=atolx):
            raise ValueError(
                f"DataArray has to be equidistant along {dim}, or cellsizes"
                " must be provided as a coordinate."
            )
        dxs = np.full(da[dim].size, dx)

    dxs = np.abs(dxs)
    x = da[dim].values
    if not da.indexes[dim].is_monotonic_increasing:
        x = x[::-1]
        dxs = dxs[::-1]

    # From here on, the coordinate is monotonically increasing
    x0 = x[0] - 0.5 * dxs[0]
    x = np.full(dxs.size + 1, x0)
    x[1:] += np.cumsum(dxs)
    return x
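

# Vertex construction sketch (coordinates assumed): three midpoints with a
# scalar cellsize coordinate dx = 1.0 produce four vertex edges.
#
#   >>> import xarray as xr
#   >>> da = xr.DataArray(
#   ...     np.zeros(3), coords={"x": [0.5, 1.5, 2.5], "dx": 1.0}, dims=["x"]
#   ... )
#   >>> _coord(da, "x")
#   array([0., 1., 2., 3.])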


def _define_single_dim_slices(src_x, dst_x, chunksizes):
    n = len(chunksizes)
    if not n > 0:
        raise ValueError("chunksizes must contain at least one chunk")
    if n == 1:
        return [slice(None, None)]

    chunk_indices = np.full(n + 1, 0)
    chunk_indices[1:] = np.cumsum(chunksizes)
    # Find the locations to cut
    src_chunk_x = src_x[chunk_indices]
    if dst_x[0] < src_chunk_x[0]:
        src_chunk_x[0] = dst_x[0]
    if dst_x[-1] > src_chunk_x[-1]:
        src_chunk_x[-1] = dst_x[-1]
    # Destinations should NOT have any overlap; sources may have overlap.
    # We find the most suitable places to cut.
    dst_i = np.searchsorted(dst_x, src_chunk_x, "left")
    dst_i[dst_i > dst_x.size - 1] = dst_x.size - 1

    # Create slices, but only if start and end differ
    # (otherwise, the slice would be empty)
    dst_slices = [slice(s, e) for s, e in zip(dst_i[:-1], dst_i[1:]) if s != e]
    return dst_slices
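

# Chunk-cutting sketch with assumed vertices: four source cells in two chunks
# of two cells each map onto two destination cells, one destination slice per
# chunk.
#
#   >>> src_x = np.array([0.0, 2.0, 4.0, 6.0, 8.0])
#   >>> dst_x = np.array([0.0, 4.0, 8.0])
#   >>> _define_single_dim_slices(src_x, dst_x, chunksizes=(2, 2))
#   [slice(0, 1, None), slice(1, 2, None)]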


def _define_slices(src, like):
    """
    Defines the slices for every dimension, based on the chunks that are
    present within src.

    First, we get a single list of chunks per dimension.
    Next, these are expanded into an N-dimensional array, equal to the number
    of dimensions that have chunks.
    Finally, these arrays are ravelled and stacked for easier iteration.
    """
    dst_dim_slices = []
    dst_chunks_shape = []
    for dim, chunksizes in zip(src.dims, src.chunks):
        if dim in like.dims:
            dst_slices = _define_single_dim_slices(
                _coord(src, dim), _coord(like, dim), chunksizes
            )
            dst_dim_slices.append(dst_slices)
            dst_chunks_shape.append(len(dst_slices))

    dst_expanded_slices = np.stack(
        [a.ravel() for a in np.meshgrid(*dst_dim_slices, indexing="ij")], axis=-1
    )
    return dst_expanded_slices, dst_chunks_shape


def _sel_chunks(da, dims, expanded_slices):
    """
    Using the slices created with the functions above, use xarray's index
    selection methods to create a list of "like" DataArrays which are used
    to inform the regridding. During the regrid() call of the
    imod.prepare.Regridder object, data from the input array is selected,
    ideally one chunk at a time, or 2 ** ndim_chunks chunks if overlap is
    required due to cellsize differences.
    """
    das = []
    for dim_slices in expanded_slices:
        slice_dict = {}
        for dim, dim_slice in zip(dims, dim_slices):
            slice_dict[dim] = dim_slice
        das.append(da.isel(**slice_dict))
    return das


def _get_method(method, methods):
    if isinstance(method, str):
        try:
            _method = methods[method]
        except KeyError as e:
            raise ValueError(
                "Invalid regridding method. Available methods are: {}".format(
                    methods.keys()
                )
            ) from e
    elif callable(method):
        _method = method
    else:
        raise TypeError("method must be a string or a callable")
    return _method


@numba.njit
def _overlap(a, b):
    # Length of the overlap between intervals a and b, clamped to zero
    return max(0, min(a[1], b[1]) - max(a[0], b[0]))
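

# Interval overlap sketch (intervals assumed): (0, 3) and (2, 5) share a
# length of 1; disjoint intervals clamp to zero.
#
#   >>> _overlap((0.0, 3.0), (2.0, 5.0))
#   1.0
#   >>> _overlap((0.0, 1.0), (2.0, 3.0))
#   0.0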


def mean(values, weights):
    vsum = 0.0
    wsum = 0.0
    for i in range(values.size):
        v = values[i]
        w = weights[i]
        if np.isnan(v):
            continue
        vsum += w * v
        wsum += w
    if wsum == 0:
        return np.nan
    else:
        return vsum / wsum
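

# Weighted mean sketch with assumed values: NaNs are skipped, and their
# weight does not count towards the denominator.
#
#   >>> mean(np.array([1.0, np.nan, 3.0]), np.array([1.0, 1.0, 1.0]))
#   2.0
#   >>> mean(np.array([1.0, 3.0]), np.array([3.0, 1.0]))
#   1.5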


def harmonic_mean(values, weights):
    v_agg = 0.0
    w_sum = 0.0
    for i in range(values.size):
        v = values[i]
        w = weights[i]
        if np.isnan(v) or v == 0:
            continue
        if w > 0:
            w_sum += w
            v_agg += w / v
    if v_agg == 0 or w_sum == 0:
        return np.nan
    else:
        return w_sum / v_agg


def geometric_mean(values, weights):
    v_agg = 0.0
    w_sum = 0.0

    # Compute the sum to normalize the weights, to avoid tiny or huge values
    # in exp
    normsum = 0.0
    for i in range(values.size):
        normsum += weights[i]
    # Early return if no values
    if normsum == 0:
        return np.nan

    for i in range(values.size):
        w = weights[i] / normsum
        v = values[i]
        # Skip if v == 0, v is NaN, or w == 0 (no contribution)
        if v > 0 and w > 0:
            v_agg += w * np.log(abs(v))
            w_sum += w
        # Do not reduce over negative values: that would require complex
        # numbers.
        elif v < 0:
            return np.nan

    if w_sum == 0:
        return np.nan
    else:
        return np.exp((1.0 / w_sum) * v_agg)
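

# Geometric mean sketch with assumed values: with equal weights this reduces
# to the ordinary geometric mean, sqrt(2 * 8) = 4 (up to floating point
# roundoff); any negative value yields NaN.
#
#   >>> geometric_mean(np.array([2.0, 8.0]), np.array([1.0, 1.0]))
#   4.0
#   >>> geometric_mean(np.array([2.0, -8.0]), np.array([1.0, 1.0]))
#   nan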


def sum(values, weights):
    v_sum = 0.0
    w_sum = 0.0
    for i in range(values.size):
        v = values[i]
        w = weights[i]
        if np.isnan(v):
            continue
        v_sum += v
        w_sum += w
    if w_sum == 0:
        return np.nan
    else:
        return v_sum


def minimum(values, weights):
    return np.nanmin(values)


def maximum(values, weights):
    return np.nanmax(values)


def mode(values, weights):
    # Area-weighted mode.
    # Reuse weights to do the counting: no allocations.
    # The alternative is defining a separate frequency array in which to add
    # the weights. This implementation is less efficient in terms of looping:
    # with many unique values, it keeps having to loop through a big part of
    # the weights array... but it would do so with a separate frequency array
    # as well. There are somewhat more elements to traverse in this case.
    s = values.size
    w_sum = 0
    for i in range(s):
        v = values[i]
        w = weights[i]
        if np.isnan(v):
            continue
        w_sum += 1
        for j in range(i):  # compare with previously found values
            if values[j] == v:  # matches a previous value
                weights[j] += w  # increase the weight of the previous value
                break

    if w_sum == 0:  # It skipped everything: only nodata values
        return np.nan
    else:  # Find the value with the highest frequency
        w_max = 0
        for i in range(s):
            w = weights[i]
            if w > w_max:
                w_max = w
                v = values[i]
        return v
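

# Mode sketch with assumed values: the value with the largest accumulated
# weight wins. Note that mode reuses, and therefore modifies, the weights
# array in place; pass a copy if the weights are needed afterwards.
#
#   >>> mode(np.array([1.0, 2.0, 2.0]), np.array([1.0, 1.0, 1.0]))
#   2.0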


def median(values, weights):
    return np.nanpercentile(values, 50)


def conductance(values, weights):
    v_agg = 0.0
    w_sum = 0.0
    for i in range(values.size):
        v = values[i]
        w = weights[i]
        if np.isnan(v):
            continue
        v_agg += v * w
        w_sum += w
    if w_sum == 0:
        return np.nan
    else:
        return v_agg


def max_overlap(values, weights):
    max_w = 0.0
    v = np.nan
    for i in range(values.size):
        w = weights[i]
        if w > max_w:
            max_w = w
            v = values[i]
    return v


METHODS: dict[str, Any] = {
    "nearest": "nearest",
    "multilinear": "multilinear",
    "mean": mean,
    "harmonic_mean": harmonic_mean,
    "geometric_mean": geometric_mean,
    "sum": sum,
    "minimum": minimum,
    "maximum": maximum,
    "mode": mode,
    "median": median,
    "conductance": conductance,
    "max_overlap": max_overlap,
}