# imod/typing/structured.py

# %%

import itertools
from collections import defaultdict
from typing import DefaultDict, List, Set, Tuple

import dask
import dask.array
import numpy as np
import xarray as xr

# %%


def check_dtypes(das: List[xr.DataArray]) -> None:
    """Check whether the dtypes of all arrays are the same."""
    dtypes = set(da.dtype for da in das)
    if len(dtypes) != 1:
        raise TypeError(f"DataArrays do not match in dtype: {dtypes}")
    return


def _is_nonunique_dimsize(sizes: Set[int]) -> bool:
    return len(sizes) != 1


def check_sizes(sizes: DefaultDict[str, Set[int]], attribute: str) -> None:
    """Utility for checking a dict of dimension names and sizes. Skips x and y."""
    sizes.pop("x", None)
    sizes.pop("y", None)
    conflicting = {k: v for k, v in sizes.items() if _is_nonunique_dimsize(v)}
    if conflicting:
        message = (
            f"DataArrays do not match in {attribute} along dimension(s):\n"
            + "\n".join([f" {k}: {v}" for k, v in conflicting.items()])
        )
        raise ValueError(message)
    return
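
# A hypothetical illustration of check_sizes: given sizes like
# {"layer": {3}, "time": {10, 12}}, the ValueError reports only "time",
# the one dimension for which more than one size was collected.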


def check_dims(das: List[xr.DataArray]) -> None:
    all_dims = set(da.dims for da in das)
    if len(all_dims) != 1:
        raise ValueError(
            f"All DataArrays should have exactly the same dimensions. Found: {all_dims}"
        )
    last_dims = das[0].dims[-2:]
    if last_dims != ("y", "x"):
        raise ValueError(f'Last dimensions must be ("y", "x"). Found: {last_dims}')
    check_dim_sizes(das)


def check_dim_sizes(das: List[xr.DataArray]) -> None:
    """Check whether all non-xy dims are equally sized."""
    sizes = defaultdict(set)
    for da in das:
        for key, value in da.sizes.items():
            sizes[key].add(value)
    check_sizes(sizes, "size")
    return


def check_coords(das: List[xr.DataArray]) -> None:
    def drop_xy(coords) -> xr.Coordinates:
        coords = dict(coords)
        coords.pop("y")
        coords.pop("x")
        return xr.Coordinates(coords)

    first_coords = drop_xy(das[0].coords)
    disjoint = [
        i + 1
        for i, da in enumerate(das[1:])
        if not first_coords.equals(drop_xy(da.coords))
    ]
    if disjoint:
        raise ValueError(
            f"Non x-y coordinates do not match for partition 0 with partitions: {disjoint}"
        )
    return


def check_chunk_sizes(das: List[xr.DataArray]) -> None:
    """Check whether all chunks are equal on non-xy dims."""
    chunks = [da.chunks for da in das]
    # Use a list rather than a generator here: all() would exhaust a
    # generator, so a subsequent any() call would always see an empty
    # iterator and return False.
    unchunked = [item is None for item in chunks]
    if all(unchunked):
        return
    if any(unchunked):
        raise ValueError("Some DataArrays are chunked, while others are not.")

    sizes = defaultdict(set)
    for da in das:
        for key, value in zip(da.dims, da.chunks):
            sizes[key].add(value)
    check_sizes(sizes, "chunks")
    return


def merge_arrays(
    arrays: List[np.ndarray],
    ixs: List[np.ndarray],
    iys: List[np.ndarray],
    yx_shape: Tuple[int, int],
) -> np.ndarray:
    """
    Merge the arrays in the last two (y, x) dimensions.

    Parameters
    ----------
    arrays: list of N np.ndarray
    ixs: list of N np.ndarray of int
        The i-th element gives the x indices of the i-th array into the
        merged array.
    iys: list of N np.ndarray of int
        The i-th element gives the y indices of the i-th array into the
        merged array.
    yx_shape: tuple of int
        The number of rows and columns of the merged array.

    Returns
    -------
    merged: np.ndarray
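
    Examples
    --------
    A minimal sketch with hypothetical offsets: two 1x2 partitions placed
    side by side in a single row.

    >>> left = np.array([[1.0, 2.0]])
    >>> right = np.array([[3.0, 4.0]])
    >>> merge_arrays([left, right], ixs=[0, 2], iys=[0, 0], yx_shape=(1, 4))
    array([[1., 2., 3., 4.]])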

124 """ 

125 first = arrays[0] 

126 shape = first.shape[:-2] + yx_shape 

127 out = np.full(shape, np.nan, dtype=first.dtype) 

128 for a, ix, iy in zip(arrays, ixs, iys): 

129 ysize, xsize = a.shape[-2:] 

130 # Create view of partition, see: 

131 # https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding 

132 out_partition_view = out[..., iy : iy + ysize, ix : ix + xsize] 

133 # Assign active values to view (updates `out` inplace) 

134 out_partition_view[...] = np.where(~np.isnan(a), a, out_partition_view) 

135 return out 

136 

137 

138def _unique_coords(das: List[xr.DataArray], dim: str) -> xr.DataArray: 

139 """Collect unique coords in list of dataarrays""" 

140 return np.unique(np.concatenate([da.coords[dim].values for da in das])) 

141 

142 

143def _is_nonequidistant_coord(da: xr.DataArray, dim: str) -> bool: 

144 return (dim in da.coords) and (da.coords[dim].size != 1) 

145 

146 

147def _merge_nonequidistant_coords( 

148 das: List[xr.DataArray], coordname: str, indices: List[np.ndarray], nsize: int 

149): 

150 dtype = das[0].coords[coordname].dtype 

151 out = np.full((nsize,), np.nan, dtype=dtype) 

152 for da, index in zip(das, indices): 

153 coords = da.coords[coordname] 

154 out[index : index + coords.size] = coords.values 

155 return out 
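
# A hypothetical illustration of _merge_nonequidistant_coords: two partitions
# with "dx" values [1., 1.] and [2., 2.] and column offsets [0, 2] produce the
# merged cell widths [1., 1., 2., 2.].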


def _merge_partitions(das: List[xr.DataArray]) -> xr.DataArray:
    # Do some input checking
    check_dtypes(das)
    check_dims(das)
    check_chunk_sizes(das)
    check_coords(das)

    # Create the x and y coordinates of the merged grid.
    x = _unique_coords(das, "x")
    y = _unique_coords(das, "y")
    nrow = y.size
    ncol = x.size
    # Compute the indices for where the different subdomain parts belong
    # in the merged grid.
    ixs = [np.searchsorted(x, da.x.values[0], side="left") for da in das]
    iys = [nrow - np.searchsorted(y, da.y.values[0], side="right") for da in das]
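    # A hypothetical illustration of the iy computation: np.unique sorts y
    # ascending, e.g. y = [1., 2., 3.], while the merged grid stores y in
    # descending order. For a partition whose topmost y is 3.,
    # np.searchsorted(y, 3., side="right") == 3, so iy == 3 - 3 == 0: that
    # partition starts at row 0 of the merged grid.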

    yx_shape = (nrow, ncol)

    # Collect coordinates
    first = das[0]
    coords = dict(first.coords)
    coords["x"] = x
    coords["y"] = y[::-1]
    if _is_nonequidistant_coord(first, "dx"):
        coords["dx"] = ("x", _merge_nonequidistant_coords(das, "dx", ixs, ncol))
    if _is_nonequidistant_coord(first, "dy"):
        coords["dy"] = ("y", _merge_nonequidistant_coords(das, "dy", iys, nrow))

    arrays = [da.data for da in das]
    if first.chunks is None:
        # If the data is in memory, merge all at once.
        data = merge_arrays(arrays, ixs, iys, yx_shape)
    else:
        # Iterate over the chunks of the dask array. Collect the chunks
        # from every partition and merge them, chunk by chunk.
        # The delayed merged results are stored as a flat list. These can
        # be directly concatenated into a new dask array if chunking occurs
        # on only the first dimension (e.g. time), but not if chunks exist
        # in multiple dimensions (e.g. time and layer).
        #
        # dask.array.block() is capable of concatenating over multiple
        # dimensions if we feed it a nested list of lists of dask arrays.
        # This is more easily represented by a numpy array of objects
        # (dask arrays), since numpy has nice tooling for reshaping.
        #
        # Normally, we'd append to a list, then convert to a numpy array and
        # reshape. However, numpy attempts to join a list of dask arrays into
        # a single large numpy array when initialized. This behavior is not
        # triggered when setting individual elements of the array, so we
        # create the numpy array in advance and set its elements.
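
        # A hypothetical illustration: with chunks along "time" and "layer",
        # block_shape might be (2, 3), and dask.array.block() then expects a
        # nested list such as
        #     [[[[b00]], [[b01]], [[b02]]],
        #      [[[b10]], [[b11]], [[b12]]]]
        # where the innermost (1, 1) nesting covers the merged y and x
        # dimensions; reshape(block_shape + (1, 1)) below provides exactly
        # this structure.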

        block_shape = das[0].data.blocks.shape[:-2]
        # int() guards against np.prod(()) returning a float for 2D arrays.
        merged_blocks = np.empty(int(np.prod(block_shape)), dtype=object)
        dimension_ranges = [range(size) for size in block_shape]
        for i, index in enumerate(itertools.product(*dimension_ranges)):
            # This is a workaround for python 3.10
            # FUTURE: can be rewritten to arr.blocks[*index, ...] in python 3.11
            index_with_ellipsis = tuple(index) + (...,)
            # arr.blocks provides us access to the chunks of the array.
            arrays_to_merge = [arr.blocks[index_with_ellipsis] for arr in arrays]
            delayed_merged = dask.delayed(merge_arrays)(
                arrays_to_merge, ixs, iys, yx_shape
            )
            dask_merged = dask.array.from_delayed(
                delayed_merged,
                shape=arrays_to_merge[0].shape[:-2] + yx_shape,
                dtype=first.dtype,
            )
            merged_blocks[i] = dask_merged

        # After merging, the xy chunks are always (1, 1)
        reshaped = merged_blocks.reshape(block_shape + (1, 1))
        data = dask.array.block(reshaped.tolist())

    return xr.DataArray(
        data=data,
        coords=coords,
        dims=first.dims,
    )


def merge_partitions(
    das: List[xr.DataArray | xr.Dataset],
) -> xr.Dataset:
    first_item = das[0]
    if isinstance(first_item, xr.Dataset):
        unique_keys = set(key for da in das for key in da.keys())
        merged_ls = []
        for key in unique_keys:
            merged_ls.append(_merge_partitions([da[key] for da in das]).rename(key))
        return xr.merge(merged_ls)
    elif isinstance(first_item, xr.DataArray):
        # Store the name so the merged result can be converted to a dataset.
        name = first_item.name
        return _merge_partitions(das).to_dataset(name=name)
    else:
        raise TypeError(
            f"Expected type: xr.DataArray or xr.Dataset, got {type(first_item)}"
        )
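

# %%
# A minimal smoke test (hypothetical data): two 2x2 partitions that sit side
# by side along x. Guarded so that importing this module stays side-effect
# free.
if __name__ == "__main__":
    left = xr.DataArray(
        np.array([[1.0, 2.0], [3.0, 4.0]]),
        coords={"y": [1.5, 0.5], "x": [0.5, 1.5]},
        dims=("y", "x"),
        name="head",
    )
    right = xr.DataArray(
        np.array([[5.0, 6.0], [7.0, 8.0]]),
        coords={"y": [1.5, 0.5], "x": [2.5, 3.5]},
        dims=("y", "x"),
        name="head",
    )
    merged = merge_partitions([left, right])
    print(merged["head"])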