Coverage for C:\src\imod-python\imod\mf6\simulation.py: 96%

485 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-16 11:25 +0200

1from __future__ import annotations 

2 

3import collections 

4import pathlib 

5import subprocess 

6import warnings 

7from copy import deepcopy 

8from pathlib import Path 

9from typing import Any, Callable, DefaultDict, Iterable, Optional, Union, cast 

10 

11import cftime 

12import dask 

13import jinja2 

14import numpy as np 

15import tomli 

16import tomli_w 

17import xarray as xr 

18import xugrid as xu 

19 

20import imod 

21import imod.logging 

22import imod.mf6.exchangebase 

23from imod.logging import standard_log_decorator 

24from imod.mf6.gwfgwf import GWFGWF 

25from imod.mf6.gwfgwt import GWFGWT 

26from imod.mf6.gwtgwt import GWTGWT 

27from imod.mf6.ims import Solution 

28from imod.mf6.interfaces.imodel import IModel 

29from imod.mf6.interfaces.isimulation import ISimulation 

30from imod.mf6.model import Modflow6Model 

31from imod.mf6.model_gwf import GroundwaterFlowModel 

32from imod.mf6.model_gwt import GroundwaterTransportModel 

33from imod.mf6.multimodel.exchange_creator_structured import ExchangeCreator_Structured 

34from imod.mf6.multimodel.exchange_creator_unstructured import ( 

35 ExchangeCreator_Unstructured, 

36) 

37from imod.mf6.multimodel.modelsplitter import create_partition_info, slice_model 

38from imod.mf6.out import open_cbc, open_conc, open_hds 

39from imod.mf6.package import Package 

40from imod.mf6.ssm import SourceSinkMixing 

41from imod.mf6.statusinfo import NestedStatusInfo 

42from imod.mf6.utilities.mask import _mask_all_models 

43from imod.mf6.utilities.regrid import _regrid_like 

44from imod.mf6.write_context import WriteContext 

45from imod.schemata import ValidationError 

46from imod.typing import GridDataArray, GridDataset 

47from imod.typing.grid import ( 

48 concat, 

49 is_equal, 

50 is_unstructured, 

51 merge_partitions, 

52) 

53 

# Output variable name -> function (from ``imod.mf6.out``) that opens it.
# Both budget variants are read with the same cell-by-cell budget reader.
OUTPUT_FUNC_MAPPING: dict[str, Callable] = {
    "head": open_hds,
    "concentration": open_conc,
    "budget-flow": open_cbc,
    "budget-transport": open_cbc,
}

# Output variable name -> model class that produces it. Used to select the
# models to read output from, and to type-check a model before reading.
OUTPUT_MODEL_MAPPING: dict[
    str, type[GroundwaterFlowModel] | type[GroundwaterTransportModel]
] = {
    "head": GroundwaterFlowModel,
    "concentration": GroundwaterTransportModel,
    "budget-flow": GroundwaterFlowModel,
    "budget-transport": GroundwaterTransportModel,
}

69 

70 

def get_models(simulation: Modflow6Simulation) -> dict[str, Modflow6Model]:
    """
    Return the models assigned to a simulation.

    Parameters
    ----------
    simulation: Modflow6Simulation
        Simulation to search.

    Returns
    -------
    dict[str, Modflow6Model]
        Simulation key -> model, for every entry that is a
        ``Modflow6Model``, in the simulation's iteration order.
    """
    models: dict[str, Modflow6Model] = {}
    for name, item in simulation.items():
        if isinstance(item, Modflow6Model):
            models[name] = item
    return models

73 

74 

def get_packages(simulation: Modflow6Simulation) -> dict[str, Package]:
    """
    Return the packages assigned directly to a simulation.

    Parameters
    ----------
    simulation: Modflow6Simulation
        Simulation to search.

    Returns
    -------
    dict[str, Package]
        Simulation key -> package, for every entry that is a ``Package``;
        all other entries are skipped. Iteration order is preserved.
    """

    def _holds_package(entry) -> bool:
        # entry is a (key, value) pair from simulation.items()
        return isinstance(entry[1], Package)

    return dict(filter(_holds_package, simulation.items()))

81 

82 

83class Modflow6Simulation(collections.UserDict, ISimulation): 

def _initialize_template(self):
    """Load the Jinja2 template ``sim-nam.j2`` used by ``render``."""
    environment = jinja2.Environment(
        loader=jinja2.PackageLoader("imod", "templates/mf6"),
        keep_trailing_newline=True,
    )
    self._template = environment.get_template("sim-nam.j2")

88 

def __init__(self, name):
    """
    Create an empty simulation.

    Parameters
    ----------
    name: str
        Name of the simulation. It is used in error messages and as the
        stem of the TOML file written by ``dump``.
    """
    super().__init__()
    # ``directory`` stays None until ``write`` succeeds; ``run`` and the
    # output readers raise a RuntimeError while it is None.
    self.directory = None
    self.name = name
    self._initialize_template()

94 

def __setitem__(self, key, value):
    # Plain pass-through to the parent mapping's item assignment. ``update``
    # routes every item through ``self[k] = v`` and therefore through here.
    super().__setitem__(key, value)

97 

def update(self, *args, **kwargs):
    """
    Add entries one at a time so every item goes through ``__setitem__``.

    Accepts the same arguments as ``dict``: a mapping or an iterable of
    key/value pairs, plus keyword arguments.
    """
    new_items = dict(*args, **kwargs)
    for key in new_items:
        self[key] = new_items[key]

101 

def time_discretization(self, times):
    """
    Deprecated alias of ``create_time_discretization``.

    Emits a ``DeprecationWarning`` and forwards ``times`` as
    ``additional_times``.
    """
    class_name = self.__class__.__name__
    message = (
        f"{class_name}.time_discretization() is deprecated. "
        f"In the future call {class_name}.create_time_discretization()."
    )
    warnings.warn(message, DeprecationWarning)
    self.create_time_discretization(additional_times=times)

109 

def create_time_discretization(self, additional_times, validate: bool = True):
    """
    Collect all unique times from model packages and the given
    ``additional_times``. These unique times are used as stress periods in
    the model. All stress packages must have the same starting time. The
    function creates a TimeDiscretization object, which is set to
    self["time_discretization"].

    The time discretization in imod-python works as follows:

    - The datetimes of all packages you send in are always respected
    - Subsequently, the input data you use is always included fully as well
    - All times are treated as starting times for the stress: a stress is
      always applied until the next specified date
    - For this reason, a final time is required to determine the length of
      the last stress period
    - Additional times can be provided to force shorter stress periods &
      more detailed output
    - Every stress has to be defined on the first stress period (this is a
      modflow requirement)

    Or visually (every letter a date in the time axes)::

        recharge a - b - c - d - e - f
        river    g - - - - h - - - - j
        times    - - - - - - - - - - - i
        model    a - b - c h d - e - f i

    with the stress periods defined between these dates. I.e. the model
    times are the set of all times you include in the model.

    Parameters
    ----------
    additional_times : str, datetime; or iterable of str, datetimes.
        Times to add to the time discretization. At least one single time
        should be given, which will be used as the ending time of the
        simulation. A single str or datetime is treated as a one-element
        list.
    validate : bool, optional
        Passed on to the TimeDiscretization object. Defaults to True.

    Note
    ----
    To set the other parameters of the TimeDiscretization object, you have
    to set these to the object after calling this function.

    Example
    -------
    >>> simulation = imod.mf6.Modflow6Simulation("example")
    >>> simulation.create_time_discretization(
    ...     additional_times=["2000-01-01", "2000-01-02"]
    ... )
    >>> # Set number of timesteps
    >>> simulation["time_discretization"]["n_timesteps"] = 5
    """
    models = [item for item in self.values() if isinstance(item, Modflow6Model)]
    # cftime is used as soon as a single model requires it.
    self.use_cftime = any(model._use_cftime() for model in models)

    # A single time (e.g. "2000-01-01") must not be iterated character by
    # character, so wrap strings and other non-iterables into a list.
    if isinstance(additional_times, str) or not isinstance(
        additional_times, Iterable
    ):
        additional_times = [additional_times]

    times = [
        imod.util.time.to_datetime_internal(time, self.use_cftime)
        for time in additional_times
    ]
    for model in models:
        times.extend(model._yield_times())

    # np.unique also sorts
    times = np.unique(np.hstack(times))

    duration = imod.util.time.timestep_duration(times, self.use_cftime)  # type: ignore
    # Every time except the last starts a stress period; the last time only
    # closes the final period.
    timestep_duration = xr.DataArray(
        duration, coords={"time": np.array(times)[:-1]}, dims=("time",)
    )
    # Other TimeDiscretization settings keep their defaults; callers adjust
    # them afterwards (see Note).
    self["time_discretization"] = imod.mf6.TimeDiscretization(
        timestep_duration=timestep_duration, validate=validate
    )

185 

def render(self, write_context: WriteContext):
    """
    Render the content of the simulation namefile (``mfsim.nam``).

    Models, the time discretization, IMS solutions and, for simulations
    with more than one model, exchanges are passed to the ``sim-nam.j2``
    template.

    Parameters
    ----------
    write_context: WriteContext
        Its ``root_directory`` is used to build the path of each model
        namefile: ``<root_directory>/<model name>/<model name>.nam``.

    Returns
    -------
    str
        The rendered namefile content.

    Raises
    ------
    KeyError
        If a solution lists a model name that is not in the simulation.
    ValueError
        If a solution contains more than one type of model.
    """
    d: dict[str, Any] = {}
    models = []
    solutiongroups = []
    for key, value in self.items():
        if isinstance(value, Modflow6Model):
            model_name_file = (
                pathlib.Path(write_context.root_directory)
                / f"{key}"
                / f"{key}.nam"
            ).as_posix()
            models.append((value.model_id, model_name_file, key))
        elif isinstance(value, Package):
            if value._pkg_id == "tdis":
                d["tdis6"] = f"{key}.tdis"
            elif value._pkg_id == "ims":
                slnnames = value["modelnames"].values
                modeltypes = set()
                for name in slnnames:
                    try:
                        modeltypes.add(type(self[name]))
                    except KeyError:
                        # The bare KeyError of the lookup adds nothing to
                        # this message, so suppress it as context.
                        raise KeyError(f"model {name} of {key} not found") from None

                if len(modeltypes) > 1:
                    raise ValueError(
                        "Only a single type of model allowed in a solution"
                    )
                solutiongroups.append(("ims6", f"{key}.ims", slnnames))

    d["models"] = models
    if len(models) > 1:
        d["exchanges"] = self.get_exchange_relationships()

    # All solutions are placed in one solution group; the template receives
    # a list of groups.
    d["solutiongroups"] = [solutiongroups]
    return self._template.render(d)

221 

@standard_log_decorator()
def write(
    self,
    directory=".",
    binary=True,
    validate: bool = True,
    use_absolute_paths=False,
):
    """
    Write Modflow6 simulation, including assigned groundwater flow and
    transport models.

    Parameters
    ----------
    directory: str, pathlib.Path
        Directory to write Modflow 6 simulation to.
    binary: ({True, False}, optional)
        Whether to write time-dependent input for stress packages as binary
        files, which are smaller in size, or more human-readable text files.
    validate: ({True, False}, optional)
        Whether to validate the Modflow6 simulation, including models, at
        write. If True, erroneous model input will throw a
        ``ValidationError``.
    use_absolute_paths: ({True, False}, optional)
        True if all paths written to the mf6 inputfiles should be absolute.

    Raises
    ------
    ValidationError
        If model validation reports errors. Errors are collected over all
        models and raised after all files have been written.

    Notes
    -----
    GWF-GWT exchanges are generated and stored under
    ``self["gwtgwf_exchanges"]``. ``self.directory`` is set only after a
    successful write.
    """
    # create write context
    write_context = WriteContext(directory, binary, use_absolute_paths)
    if self.is_split():
        write_context.is_partitioned = True

    # Check models for required content
    for key, model in self.items():
        # skip timedis, exchanges
        if isinstance(model, Modflow6Model):
            model._model_checks(key)

    # Generate GWF-GWT exchanges; stored on the simulation so they end up in
    # the namefile (via get_exchange_relationships) and are written below.
    if gwfgwt_exchanges := self._generate_gwfgwt_exchanges():
        self["gwtgwf_exchanges"] = gwfgwt_exchanges

    directory = pathlib.Path(directory)
    directory.mkdir(exist_ok=True, parents=True)

    # Write simulation namefile
    mfsim_content = self.render(write_context)
    mfsim_path = directory / "mfsim.nam"
    with open(mfsim_path, "w") as f:
        f.write(mfsim_content)

    # Write time discretization file
    self["time_discretization"].write(directory, "time_discretization")

    # Write individual models, solutions and exchanges
    status_info = NestedStatusInfo("Simulation validation status")
    globaltimes = self["time_discretization"]["time"].values
    for key, value in self.items():
        if isinstance(value, Modflow6Model):
            # Each model gets its own copy of the context, rooted at the
            # simulation directory.
            model_write_context = write_context.copy_with_new_write_directory(
                write_context.simulation_directory
            )
            status_info.add(
                value.write(
                    modelname=key,
                    globaltimes=globaltimes,
                    validate=validate,
                    write_context=model_write_context,
                )
            )
        elif isinstance(value, Package):
            # The time discretization was written above; only solutions here.
            if value._pkg_id == "ims":
                ims_write_context = write_context.copy_with_new_write_directory(
                    write_context.simulation_directory
                )
                value.write(key, globaltimes, ims_write_context)
        elif isinstance(value, list):
            for exchange in value:
                if isinstance(exchange, imod.mf6.exchangebase.ExchangeBase):
                    exchange.write(
                        exchange.package_name(), globaltimes, write_context
                    )

    if status_info.has_errors():
        raise ValidationError("\n" + status_info.to_string())

    self.directory = directory

309 

def run(self, mf6path: Union[str, Path] = "mf6") -> None:
    """
    Run Modflow 6 simulation. This method runs a subprocess calling
    ``mf6path``. This argument is set to ``mf6``, which means the Modflow 6
    executable is expected to be added to your PATH environment variable.
    :doc:`See this writeup how to add Modflow 6 to your PATH on Windows </examples/mf6/index>`

    Note that the ``write`` method needs to be called before this method is
    called.

    Parameters
    ----------
    mf6path: Union[str, Path]
        Path to the Modflow 6 executable. Defaults to calling ``mf6``.

    Raises
    ------
    RuntimeError
        If the simulation has not been written yet, or if the executable
        exits with a non-zero return code (the captured stdout is included
        in the message).

    Examples
    --------
    Make sure you write your model first

    >>> simulation.write(path/to/model)
    >>> simulation.run()
    """
    if self.directory is None:
        raise RuntimeError(f"Simulation {self.name} has not been written yet.")

    # Execute inside the simulation directory, where mfsim.nam was written.
    with imod.util.cd(self.directory):
        completed = subprocess.run(mf6path, capture_output=True)
        returncode = completed.returncode
        if returncode != 0:
            captured_output = completed.stdout.decode()
            raise RuntimeError(
                f"Simulation {self.name}: {mf6path} failed to run with returncode "
                f"{returncode}, and error message:\n\n{captured_output} "
            )

341 

def open_head(
    self,
    dry_nan: bool = False,
    simulation_start_time: Optional[np.datetime64] = None,
    time_unit: Optional[str] = "d",
) -> GridDataArray:
    """
    Open heads of finished simulation, requires that the ``run`` method has
    been called.

    The data is lazily read per timestep and automatically converted into
    (dense) xr.DataArrays or xu.UgridDataArrays, for DIS and DISV
    respectively. The conversion is done via the information stored in the
    Binary Grid file (GRB).

    Parameters
    ----------
    dry_nan: bool, default value: False.
        Whether to convert dry values to NaN.
    simulation_start_time : Optional datetime
        The time and date corresponding to the beginning of the simulation.
        Use this to convert the time coordinates of the output array to
        calendar time/dates. time_unit must also be present if this argument
        is present.
    time_unit: Optional str
        The time unit MF6 is working in, in string representation.
        Only used if simulation_start_time was provided.
        Admissible values are:
        ns -> nanosecond
        ms -> millisecond
        s -> second
        m -> minute
        h -> hour
        d -> day
        w -> week
        Units "month" or "year" are not supported, as they do not represent
        unambiguous timedelta durations.

    Returns
    -------
    head: Union[xr.DataArray, xu.UgridDataArray]
        For a split simulation, the heads of all partitions merged.

    Examples
    --------
    Make sure you write and run your model first

    >>> simulation.write(path/to/model)
    >>> simulation.run()

    Then open heads:

    >>> head = simulation.open_head()
    """
    return self._open_output(
        "head",
        dry_nan=dry_nan,
        simulation_start_time=simulation_start_time,
        time_unit=time_unit,
    )

399 

def open_transport_budget(
    self,
    species_ls: Optional[list[str]] = None,
    simulation_start_time: Optional[np.datetime64] = None,
    time_unit: Optional[str] = "d",
) -> GridDataArray | GridDataset:
    """
    Open transport budgets of finished simulation, requires that the ``run``
    method has been called.

    The data is lazily read per timestep and automatically converted into
    (dense) xr.DataArrays or xu.UgridDataArrays, for DIS and DISV
    respectively. The conversion is done via the information stored in the
    Binary Grid file (GRB).

    Parameters
    ----------
    species_ls: list of strings, default value: None.
        List of species names, which will be used to concatenate the
        budgets along the ``"species"`` dimension, in case the
        simulation has multiple species and thus multiple transport models.
        If None, transport model names will be used as species names.
    simulation_start_time : Optional datetime
        The time and date corresponding to the beginning of the simulation.
        Use this to convert the time coordinates of the output to calendar
        time/dates. time_unit must also be present if this argument is
        present.
    time_unit: Optional str, default value: "d".
        The time unit MF6 is working in, in string representation. Only used
        if simulation_start_time was provided.

    Returns
    -------
    budget: Union[xr.Dataset, xu.UgridDataset]
        Budgets are requested with ``merge_to_dataset=True``, so every
        budget term is a variable of one dataset. With multiple species, the
        datasets are concatenated along a ``"species"`` dimension.
        Flow-ja-face values are always returned in grid form.
    """
    return self._open_output(
        "budget-transport",
        species_ls=species_ls,
        simulation_start_time=simulation_start_time,
        time_unit=time_unit,
        merge_to_dataset=True,
        flowja=False,
    )

438 

def open_flow_budget(
    self,
    flowja: bool = False,
    simulation_start_time: Optional[np.datetime64] = None,
    time_unit: Optional[str] = "d",
) -> GridDataArray | GridDataset:
    """
    Open flow budgets of finished simulation, requires that the ``run``
    method has been called.

    The data is lazily read per timestep and automatically converted into
    (dense) xr.DataArrays or xu.UgridDataArrays, for DIS and DISV
    respectively. The conversion is done via the information stored in the
    Binary Grid file (GRB).

    The ``flowja`` argument controls whether the flow-ja-face array (if
    present) is returned in grid form or "as is". By default
    ``flowja=False`` and the array is returned in "grid form", meaning:

        * DIS: in right, front, and lower face flow. All flows are placed in
          the cell.
        * DISV: in horizontal and lower face flow. The horizontal flows are
          placed on the edges and the lower face flow is placed on the faces.

    When ``flowja=True``, the flow-ja-face array is returned as it is found in
    the CBC file, with a flow for every cell to cell connection. Additionally,
    a ``connectivity`` DataArray is returned describing for every cell (n) its
    connected cells (m). ``flowja=True`` is not supported for split
    simulations.

    Parameters
    ----------
    flowja: bool, default value: False
        Whether to return the flow-ja-face values "as is" (``True``) or in a
        grid form (``False``).
    simulation_start_time : Optional datetime
        The time and date corresponding to the beginning of the simulation.
        Use this to convert the time coordinates of the output to calendar
        time/dates. time_unit must also be present if this argument is
        present.
    time_unit: Optional str, default value: "d".
        The time unit MF6 is working in, in string representation. Only used
        if simulation_start_time was provided.

    Returns
    -------
    budget: Union[xr.Dataset, xu.UgridDataset]
        Budgets are requested with ``merge_to_dataset=True``, so every
        budget term is a variable of one dataset; for DIS these have
        dimensions ("time", "layer", "y", "x").

    Examples
    --------
    Make sure you write and run your model first

    >>> simulation.write(path/to/model)
    >>> simulation.run()

    Then open budgets:

    >>> budget = simulation.open_flow_budget()

    Check the contents:

    >>> print(budget.keys())

    Get the drainage budget, compute a time mean for the first layer:

    >>> drn_budget = budget["drn"]
    >>> mean = drn_budget.sel(layer=1).mean("time")
    """
    return self._open_output(
        "budget-flow",
        flowja=flowja,
        simulation_start_time=simulation_start_time,
        time_unit=time_unit,
        merge_to_dataset=True,
    )

507 

def open_concentration(
    self,
    species_ls: Optional[list[str]] = None,
    dry_nan: bool = False,
    simulation_start_time: Optional[np.datetime64] = None,
    time_unit: Optional[str] = "d",
) -> GridDataArray:
    """
    Open concentration of finished simulation, requires that the ``run``
    method has been called.

    The data is lazily read per timestep and automatically converted into
    (dense) xr.DataArrays or xu.UgridDataArrays, for DIS and DISV
    respectively. The conversion is done via the information stored in the
    Binary Grid file (GRB).

    Parameters
    ----------
    species_ls: list of strings, default value: None.
        List of species names, which will be used to concatenate the
        concentrations along the ``"species"`` dimension, in case the
        simulation has multiple species and thus multiple transport models.
        If None, transport model names will be used as species names.
    dry_nan: bool, default value: False.
        Whether to convert dry values to NaN.
    simulation_start_time : Optional datetime
        The time and date corresponding to the beginning of the simulation.
        Use this to convert the time coordinates of the output array to
        calendar time/dates. time_unit must also be present if this argument
        is present.
    time_unit: Optional str, default value: "d".
        The time unit MF6 is working in, in string representation. Only used
        if simulation_start_time was provided.

    Returns
    -------
    concentration: Union[xr.DataArray, xu.UgridDataArray]

    Examples
    --------
    Make sure you write and run your model first

    >>> simulation.write(path/to/model)
    >>> simulation.run()

    Then open concentrations:

    >>> concentration = simulation.open_concentration()
    """
    return self._open_output(
        "concentration",
        species_ls=species_ls,
        dry_nan=dry_nan,
        simulation_start_time=simulation_start_time,
        time_unit=time_unit,
    )

556 

def _open_output(self, output: str, **settings) -> GridDataArray | GridDataset:
    """
    Opens output of one or multiple models.

    Parameters
    ----------
    output: str
        Output variable name to open; a key of ``OUTPUT_MODEL_MAPPING``:
        ``"head"``, ``"concentration"``, ``"budget-flow"`` or
        ``"budget-transport"``.
    **settings:
        Extra settings that need to be passed through to the respective
        output function.

    Raises
    ------
    KeyError
        If ``output`` is not a key of ``OUTPUT_MODEL_MAPPING``.
    ValueError
        If no model of the type producing ``output`` is in the simulation.
    """
    modeltype = OUTPUT_MODEL_MAPPING[output]
    modelnames = self.get_models_of_type(modeltype._model_id).keys()
    if len(modelnames) == 0:
        raise ValueError(
            f"Could not find any models of appropriate type for {output}, "
            f"make sure a model of type {modeltype} is assigned to simulation."
        )

    if output in ("head", "budget-flow"):
        return self._open_single_output(list(modelnames), output, **settings)
    if output in ("concentration", "budget-transport"):
        return self._concat_species(output, **settings)
    # Only reachable if OUTPUT_MODEL_MAPPING gains a key not handled above.
    raise RuntimeError(f"Unexpected error when opening {output} for {modelnames}")

587 

def _open_single_output(
    self, modelnames: list[str], output: str, **settings
) -> GridDataArray | GridDataset:
    """
    Open single output, e.g. concentration of single species, or heads. This
    can be output of partitioned models that need to be merged.

    Parameters
    ----------
    modelnames: list of str
        Names of the models to read from: one model, or all partition models
        of a split simulation.
    output: str
        Output variable name to open.
    **settings:
        Passed through to the output function.

    Raises
    ------
    ValueError
        If ``modelnames`` is empty, or contains several models while the
        simulation is not split.
    """
    if len(modelnames) == 0:
        modeltype = OUTPUT_MODEL_MAPPING[output]
        raise ValueError(
            f"Could not find any models of appropriate type for {output}, "
            f"make sure a model of type {modeltype} is assigned to simulation."
        )
    if len(modelnames) == 1:
        modelname = next(iter(modelnames))
        return self._open_single_output_single_model(modelname, output, **settings)
    if not self.is_split():
        raise ValueError(
            f"Found multiple models for {output}: {list(modelnames)}, but the "
            "simulation is not split. Opening output of multiple models is only "
            "supported for the partitions of a split simulation."
        )
    # Partitioned output: budgets and states are merged differently.
    if "budget" in output:
        return self._merge_budgets(modelnames, output, **settings)
    return self._merge_states(modelnames, output, **settings)

610 

def _merge_states(
    self, modelnames: list[str], output: str, **settings
) -> GridDataArray:
    """
    Open a state variable (e.g. head or concentration) for every partition
    model in ``modelnames`` and merge the partitions into one grid.
    """
    partitions = [
        self._open_single_output_single_model(name, output, **settings)
        for name in modelnames
    ]
    return merge_partitions(partitions)

620 

def _merge_and_assign_exchange_budgets(self, cbc: GridDataset) -> GridDataset:
    """
    Merge the exchange budgets of a partition into a single budget variable:
    cbc[[gwf-gwf_1, gwf-gwf_3]] to cbc[gwf-gwf]

    All variables whose name contains ``"gwf-gwf"`` or ``"gwt-gwt"`` are
    summed into one variable and dropped from ``cbc``. The new variable is
    named after the second ``"_"``-separated field of the first exchange
    variable name, which is expected to be ``"gwf-gwf"`` or ``"gwt-gwt"``.

    Returns ``cbc`` unchanged when it contains no exchange budgets.
    """
    exchange_names = [
        key
        for key in cast(Iterable[str], cbc.keys())
        if ("gwf-gwf" in key) or ("gwt-gwt" in key)
    ]
    # Nothing to merge; indexing exchange_names[0] below would otherwise
    # raise an IndexError.
    if not exchange_names:
        return cbc

    exchange_budgets = cbc[exchange_names].to_array().sum(dim="variable")
    cbc = cbc.drop_vars(exchange_names)
    # "gwf-gwf" or "gwt-gwt"
    exchange_key = exchange_names[0].split("_")[1]
    cbc[exchange_key] = exchange_budgets
    return cbc

637 

def _pad_missing_variables(self, cbc_per_partition: list[GridDataset]) -> None:
    """
    Boundary conditions can be missing in certain partitions, as do their
    budgets, in which case we manually assign an empty grid of nans.

    Every dataset in ``cbc_per_partition`` is modified in place so that it
    contains each budget variable found in any partition. Missing variables
    get a lazily evaluated (dask) NaN grid with the dimensions the variable
    has elsewhere and the coordinates of the partition being padded.

    Parameters
    ----------
    cbc_per_partition: list of xr.Dataset or xu.UgridDataset
        Budgets per partition; modified in place.
    """
    # ``import dask`` alone does not guarantee that the ``dask.array``
    # submodule is loaded, so import it explicitly.
    import dask.array as dask_array

    # First occurrence wins; keeps a deterministic order of variables.
    dims_per_unique_key = {
        key: cbc[key].dims for cbc in cbc_per_partition for key in cbc.keys()
    }
    for cbc in cbc_per_partition:
        present_keys = set(cbc.keys())
        # Iterate in first-occurrence order rather than over a set, so the
        # order in which variables are added does not depend on hash seeds.
        missing_keys = [key for key in dims_per_unique_key if key not in present_keys]

        for missing in missing_keys:
            missing_dims = dims_per_unique_key[missing]
            missing_coords = {dim: cbc.coords[dim] for dim in missing_dims}

            shape = tuple(len(missing_coords[dim]) for dim in missing_dims)
            # One chunk per entry of the leading dimension (presumably
            # "time" -- TODO confirm), one chunk for the remaining dimensions.
            chunks = (1,) + shape[1:]
            missing_data = dask_array.full(shape, np.nan, chunks=chunks)

            missing_grid = xr.DataArray(
                missing_data, dims=missing_dims, coords=missing_coords
            )
            if isinstance(cbc, xu.UgridDataset):
                missing_grid = xu.UgridDataArray(
                    missing_grid,
                    grid=cbc.ugrid.grid,
                )
            cbc[missing] = missing_grid

666 

def _merge_budgets(
    self, modelnames: list[str], output: str, **settings
) -> GridDataset:
    """
    Open the budgets of every partition model and merge them into one dataset.

    Per partition, the exchange budgets are merged into a single variable
    and, for structured grids, values outside the partition model's
    ``domain`` are set to NaN. Budgets absent in some partitions are padded
    with NaN before merging.

    Raises
    ------
    ValueError
        If ``flowja`` is requested; it is not supported when merging budgets.
    """
    # ``flowja`` may be absent from settings; treat that as False. Any truthy
    # value requests flow-ja-face output, which cannot be merged here.
    if settings.get("flowja", False):
        raise ValueError("``flowja`` cannot be set to True when merging budgets.")

    cbc_per_partition = []
    for modelname in modelnames:
        cbc = self._open_single_output_single_model(modelname, output, **settings)
        # Merge and assign exchange budgets to dataset
        # FUTURE: Refactor to insert these exchange budgets in horizontal
        # flows.
        cbc = self._merge_and_assign_exchange_budgets(cbc)
        if not is_unstructured(cbc):
            cbc = cbc.where(self[modelname].domain, other=np.nan)
        cbc_per_partition.append(cbc)

    self._pad_missing_variables(cbc_per_partition)

    return merge_partitions(cbc_per_partition)

687 

def _concat_species(
    self, output: str, species_ls: Optional[list[str]] = None, **settings
) -> GridDataArray | GridDataset:
    """
    Open ``output`` for all transport models, one entry per species.

    Parameters
    ----------
    output: str
        Output variable name to open, e.g. ``"concentration"``.
    species_ls: list of str, optional
        Species names. When empty or None, the transport model names are
        used, with the partition suffix stripped for split simulations.
    **settings:
        Passed through to the output function.

    Returns
    -------
    The output of the single species, or the outputs of all species
    concatenated along a ``"species"`` dimension.

    Raises
    ------
    ValueError
        If the length of ``species_ls`` does not match the number of species.
    """
    # Group transport models per flow model, to somewhat enforce a
    # consistent transport model ordering. With F = flow model,
    # T = transport model, a/b = species, 1/2 = partition:
    #   F1Ta1 F1Tb1 F2Ta2 F2Tb2 -> F1: [Ta1, Tb1], F2: [Ta2, Tb2]
    #   F1Ta1 F2Ta2 F1Tb1 F2Tb2 -> F1: [Ta1, Tb1], F2: [Ta2, Tb2]
    names_per_flow_model = list(
        self._get_transport_models_per_flow_model().values()
    )

    # Transpose: [[Ta_1, Tb_1], [Ta_2, Tb_2]] -> [(Ta_1, Ta_2), (Tb_1, Tb_2)]
    #            [[Ta, Tb]]                   -> [(Ta,), (Tb,)]
    names_per_species = list(zip(*names_per_flow_model))

    if self.is_split():
        # Strip the partition suffix: "Ta_1" -> "Ta"
        default_species = [
            name.rpartition("_")[0] for name in names_per_flow_model[0]
        ]
    else:
        default_species = names_per_flow_model[0]

    if not species_ls:
        species_ls = default_species

    n_species = len(names_per_species)
    if len(species_ls) != n_species:
        raise ValueError(
            "species_ls does not equal the number of transport models, "
            f"expected length {n_species}, received {species_ls}"
        )

    if n_species == 1:
        return self._open_single_output(
            list(names_per_species[0]), output, **settings
        )

    # Open every species separately, label it, then stack along "species".
    labelled = []
    for species, names in zip(species_ls, names_per_species):
        data = self._open_single_output(list(names), output, **settings)
        labelled.append(data.assign_coords(species=species))
    return concat(labelled, dim="species")

736 

def _open_single_output_single_model(
    self, modelname: str, output: str, **settings
) -> GridDataArray | GridDataset:
    """
    Opens single output of single model

    Parameters
    ----------
    modelname: str
        Name of groundwater model from which output should be read.
    output: str
        Output variable name to open.
    **settings:
        Extra settings that need to be passed through to the respective
        output function.

    Raises
    ------
    RuntimeError
        If the simulation has not been written, or the output file is absent.
    TypeError
        If the model is not of the type that produces ``output``.
    """
    reader = OUTPUT_FUNC_MAPPING[output]
    required_type = OUTPUT_MODEL_MAPPING[output]

    if self.directory is None:
        raise RuntimeError(f"Simulation {self.name} has not been written yet.")
    model_directory = self.directory / modelname

    model = self[modelname]
    if not isinstance(model, required_type):
        raise TypeError(
            f"{modelname} not a {required_type}, instead got {type(model)}"
        )

    # The output path is configured in the model's output control package.
    output_control = model[model._get_pkgkey("oc")]
    # "budget-flow" / "budget-transport" are both "budget" for output control.
    oc_variable = output.partition("-")[0]
    oc_path = output_control._get_output_filepath(model_directory, oc_variable)
    # Force path to always include simulation directory.
    output_path = self.directory / oc_path

    grb_path = self._get_grb_path(modelname)

    if not output_path.exists():
        raise RuntimeError(
            f"Could not find output in {output_path}, check if you already ran simulation {self.name}"
        )

    return reader(output_path, grb_path, **settings)

783 

def _get_flow_modelname_coupled_to_transport_model(
    self, transport_modelname: str
) -> str:
    """
    Get name of flow model coupled to transport model.

    Searches the GWF6-GWT6 exchange specifications, which are indexed as
    ``(exchange_type, filename, flow_model_name, transport_model_name)``.

    Raises
    ------
    ValueError
        If not exactly one flow model is coupled to the transport model.
    """
    coupled_flow_models = [
        spec[2]
        for spec in self.get_exchange_relationships()
        if spec[0] == "GWF6-GWT6" and spec[3] == transport_modelname
    ]
    if len(coupled_flow_models) != 1:
        raise ValueError(
            f"Exactly one flow model must be coupled to transport model {transport_modelname}, got: {coupled_flow_models}"
        )
    return coupled_flow_models[0]

802 

def _get_grb_path(self, modelname: str) -> Path:
    """
    Finds appropriate grb path belonging to modelname. Grb files are not
    written for transport models, so this method always returns a path to a
    flow model. In case of a transport model, it returns the path to the grb
    file of its coupled flow model.

    Raises
    ------
    RuntimeError
        If the simulation has not been written yet.
    """
    if self.directory is None:
        raise RuntimeError(f"Simulation {self.name} has not been written yet.")

    model = self[modelname]
    if isinstance(model, GroundwaterTransportModel):
        flow_model_name = self._get_flow_modelname_coupled_to_transport_model(
            modelname
        )
    else:
        flow_model_name = modelname
    flow_model = self[flow_model_name]

    # The grb file belongs to the flow model, so its name must follow the
    # flow model's discretization package key, not the transport model's.
    diskey = flow_model._get_diskey()
    dis_id = flow_model[diskey]._pkg_id
    return self.directory / flow_model_name / f"{diskey}.{dis_id}.grb"

823 

@standard_log_decorator()
def dump(
    self,
    directory=".",
    validate: bool = True,
    mdal_compliant: bool = False,
    crs=None,
) -> None:
    """
    Dump simulation to files. Writes a model definition as .TOML file, which
    points to data for each package. Each package is stored as a separate
    NetCDF. Structured grids are saved as regular NetCDFs, unstructured
    grids are saved as UGRID NetCDF. Structured grids are always made GDAL
    compliant, unstructured grids can be made MDAL compliant optionally.

    The TOML file is named ``<simulation name>.toml``; every path listed in
    it is relative to ``directory``.

    Parameters
    ----------
    directory: str or Path, optional
        directory to dump simulation into. Defaults to current working directory.
    validate: bool, optional
        Whether to validate simulation data. Defaults to True.
    mdal_compliant: bool, optional
        Convert data with
        :func:`imod.prepare.spatial.mdal_compliant_ugrid2d` to MDAL
        compliant unstructured grids. Defaults to False.
    crs: Any, optional
        Anything accepted by rasterio.crs.CRS.from_user_input
        Requires ``rioxarray`` installed.
    """
    directory = pathlib.Path(directory)
    directory.mkdir(parents=True, exist_ok=True)

    # class name -> {key: relative path}; exchange lists instead map
    # exchange class name -> [paths].
    toml_content: DefaultDict[str, dict] = collections.defaultdict(dict)
    for key, value in self.items():
        if isinstance(value, Modflow6Model):
            model_toml_path = value.dump(
                directory, key, validate, mdal_compliant, crs
            )
            relative_path = model_toml_path.relative_to(directory).as_posix()
            toml_content[type(value).__name__][key] = relative_path
        elif key in ["gwtgwf_exchanges", "split_exchanges"]:
            paths_per_class: DefaultDict[str, list] = collections.defaultdict(list)
            toml_content[key] = paths_per_class
            for exchange_package in value:
                _, filename, _, _ = exchange_package.get_specification()
                netcdf_name = f"{filename}.nc"
                exchange_package.dataset.to_netcdf(directory / netcdf_name)
                paths_per_class[type(exchange_package).__name__].append(netcdf_name)
        else:
            netcdf_name = f"{key}.nc"
            value.dataset.to_netcdf(directory / netcdf_name)
            toml_content[type(value).__name__][key] = netcdf_name

    toml_file = directory / f"{self.name}.toml"
    with open(toml_file, "wb") as f:
        tomli_w.dump(toml_content, f)

884 

885 @staticmethod 

886 def from_file(toml_path): 

887 classes = { 

888 item_cls.__name__: item_cls 

889 for item_cls in ( 

890 GroundwaterFlowModel, 

891 GroundwaterTransportModel, 

892 imod.mf6.TimeDiscretization, 

893 imod.mf6.Solution, 

894 imod.mf6.GWFGWF, 

895 imod.mf6.GWFGWT, 

896 imod.mf6.GWTGWT, 

897 ) 

898 } 

899 

900 toml_path = pathlib.Path(toml_path) 

901 with open(toml_path, "rb") as f: 

902 toml_content = tomli.load(f) 

903 

904 simulation = Modflow6Simulation(name=toml_path.stem) 

905 for key, entry in toml_content.items(): 

906 if key not in ["gwtgwf_exchanges", "split_exchanges"]: 

907 item_cls = classes[key] 

908 for name, filename in entry.items(): 

909 path = toml_path.parent / filename 

910 simulation[name] = item_cls.from_file(path) 

911 else: 

912 simulation[key] = [] 

913 for exchange_class, exchange_list in entry.items(): 

914 item_cls = classes[exchange_class] 

915 for filename in exchange_list: 

916 path = toml_path.parent / filename 

917 simulation[key].append(item_cls.from_file(path)) 

918 

919 return simulation 

920 

921 def get_exchange_relationships(self): 

922 result = [] 

923 

924 if "gwtgwf_exchanges" in self: 

925 for exchange in self["gwtgwf_exchanges"]: 

926 result.append(exchange.get_specification()) 

927 

928 # exchange for splitting models 

929 if self.is_split(): 

930 for exchange in self["split_exchanges"]: 

931 result.append(exchange.get_specification()) 

932 return result 

933 

934 def get_models_of_type(self, model_id) -> dict[str, IModel]: 

935 return { 

936 k: v 

937 for k, v in self.items() 

938 if isinstance(v, Modflow6Model) and (v.model_id == model_id) 

939 } 

940 

941 def get_models(self): 

942 return {k: v for k, v in self.items() if isinstance(v, Modflow6Model)} 

943 

944 def clip_box( 

945 self, 

946 time_min: Optional[cftime.datetime | np.datetime64 | str] = None, 

947 time_max: Optional[cftime.datetime | np.datetime64 | str] = None, 

948 layer_min: Optional[int] = None, 

949 layer_max: Optional[int] = None, 

950 x_min: Optional[float] = None, 

951 x_max: Optional[float] = None, 

952 y_min: Optional[float] = None, 

953 y_max: Optional[float] = None, 

954 states_for_boundary: Optional[dict[str, GridDataArray]] = None, 

955 ) -> Modflow6Simulation: 

956 """ 

957 Clip a simulation by a bounding box (time, layer, y, x). 

958 

959 Slicing intervals may be half-bounded, by providing None: 

960 

961 * To select 500.0 <= x <= 1000.0: 

962 ``clip_box(x_min=500.0, x_max=1000.0)``. 

963 * To select x <= 1000.0: ``clip_box(x_min=None, x_max=1000.0)`` 

964 or ``clip_box(x_max=1000.0)``. 

965 * To select x >= 500.0: ``clip_box(x_min = 500.0, x_max=None.0)`` 

966 or ``clip_box(x_min=1000.0)``. 

967 

968 Parameters 

969 ---------- 

970 time_min: optional 

971 time_max: optional 

972 layer_min: optional, int 

973 layer_max: optional, int 

974 x_min: optional, float 

975 x_max: optional, float 

976 y_min: optional, float 

977 y_max: optional, float 

978 states_for_boundary : optional, Dict[pkg_name:str, boundary_values:Union[xr.DataArray, xu.UgridDataArray]] 

979 

980 Returns 

981 ------- 

982 clipped : Simulation 

983 """ 

984 

985 if self.is_split(): 

986 raise RuntimeError( 

987 "Unable to clip simulation. Clipping can only be done on simulations that haven't been split." 

988 + "Therefore clipping should be done before splitting the simulation." 

989 ) 

990 if not self.has_one_flow_model(): 

991 raise ValueError( 

992 "Unable to clip simulation. Clipping can only be done on simulations that have a single flow model ." 

993 ) 

994 for model_name, model in self.get_models().items(): 

995 supported, error_with_object = model.is_clipping_supported() 

996 if not supported: 

997 raise ValueError( 

998 f"simulation cannot be clipped due to presence of package '{error_with_object}' in model '{model_name}'" 

999 ) 

1000 

1001 clipped = type(self)(name=self.name) 

1002 for key, value in self.items(): 

1003 state_for_boundary = ( 

1004 None if states_for_boundary is None else states_for_boundary.get(key) 

1005 ) 

1006 if isinstance(value, Modflow6Model): 

1007 clipped[key] = value.clip_box( 

1008 time_min=time_min, 

1009 time_max=time_max, 

1010 layer_min=layer_min, 

1011 layer_max=layer_max, 

1012 x_min=x_min, 

1013 x_max=x_max, 

1014 y_min=y_min, 

1015 y_max=y_max, 

1016 state_for_boundary=state_for_boundary, 

1017 ) 

1018 elif isinstance(value, Package): 

1019 clipped[key] = value.clip_box( 

1020 time_min=time_min, 

1021 time_max=time_max, 

1022 layer_min=layer_min, 

1023 layer_max=layer_max, 

1024 x_min=x_min, 

1025 x_max=x_max, 

1026 y_min=y_min, 

1027 y_max=y_max, 

1028 ) 

1029 else: 

1030 raise ValueError(f"object of type {type(value)} cannot be clipped.") 

1031 return clipped 

1032 

1033 def split(self, submodel_labels: GridDataArray) -> Modflow6Simulation: 

1034 """ 

1035 Split a simulation in different partitions using a submodel_labels array. 

1036 

1037 The submodel_labels array defines how a simulation will be split. The array should have the same topology as 

1038 the domain being split i.e. similar shape as a layer in the domain. The values in the array indicate to 

1039 which partition a cell belongs. The values should be zero or greater. 

1040 

1041 The method return a new simulation containing all the split models and packages 

1042 """ 

1043 if self.is_split(): 

1044 raise RuntimeError( 

1045 "Unable to split simulation. Splitting can only be done on simulations that haven't been split." 

1046 ) 

1047 

1048 if not self.has_one_flow_model(): 

1049 raise ValueError( 

1050 "splitting of simulations with more (or less) than 1 flow model currently not supported." 

1051 ) 

1052 transport_models = self.get_models_of_type("gwt6") 

1053 flow_models = self.get_models_of_type("gwf6") 

1054 if not any(flow_models) and not any(transport_models): 

1055 raise ValueError("a simulation without any models cannot be split.") 

1056 

1057 original_models = {**flow_models, **transport_models} 

1058 for model_name, model in original_models.items(): 

1059 supported, error_with_object = model.is_splitting_supported() 

1060 if not supported: 

1061 raise ValueError( 

1062 f"simulation cannot be split due to presence of package '{error_with_object}' in model '{model_name}'" 

1063 ) 

1064 

1065 original_packages = get_packages(self) 

1066 

1067 partition_info = create_partition_info(submodel_labels) 

1068 

1069 exchange_creator: ExchangeCreator_Unstructured | ExchangeCreator_Structured 

1070 if is_unstructured(submodel_labels): 

1071 exchange_creator = ExchangeCreator_Unstructured( 

1072 submodel_labels, partition_info 

1073 ) 

1074 else: 

1075 exchange_creator = ExchangeCreator_Structured( 

1076 submodel_labels, partition_info 

1077 ) 

1078 

1079 new_simulation = imod.mf6.Modflow6Simulation(f"{self.name}_partioned") 

1080 for package_name, package in {**original_packages}.items(): 

1081 new_simulation[package_name] = deepcopy(package) 

1082 

1083 for model_name, model in original_models.items(): 

1084 solution_name = self.get_solution_name(model_name) 

1085 new_simulation[solution_name].remove_model_from_solution(model_name) 

1086 for submodel_partition_info in partition_info: 

1087 new_model_name = f"{model_name}_{submodel_partition_info.id}" 

1088 new_simulation[new_model_name] = slice_model( 

1089 submodel_partition_info, model 

1090 ) 

1091 new_simulation[solution_name].add_model_to_solution(new_model_name) 

1092 

1093 exchanges: list[Any] = [] 

1094 

1095 for flow_model_name, flow_model in flow_models.items(): 

1096 exchanges += exchange_creator.create_gwfgwf_exchanges( 

1097 flow_model_name, flow_model.domain.layer 

1098 ) 

1099 

1100 if any(transport_models): 

1101 for tpt_model_name in transport_models: 

1102 exchanges += exchange_creator.create_gwtgwt_exchanges( 

1103 tpt_model_name, flow_model_name, model.domain.layer 

1104 ) 

1105 new_simulation._add_modelsplit_exchanges(exchanges) 

1106 new_simulation._update_buoyancy_packages() 

1107 new_simulation._set_flow_exchange_options() 

1108 new_simulation._set_transport_exchange_options() 

1109 new_simulation._update_ssm_packages() 

1110 

1111 new_simulation._filter_inactive_cells_from_exchanges() 

1112 return new_simulation 

1113 

1114 def regrid_like( 

1115 self, 

1116 regridded_simulation_name: str, 

1117 target_grid: GridDataArray, 

1118 validate: bool = True, 

1119 ) -> "Modflow6Simulation": 

1120 """ 

1121 This method creates a new simulation object. The models contained in the new simulation are regridded versions 

1122 of the models in the input object (this). 

1123 Time discretization and solver settings are copied. 

1124 

1125 Parameters 

1126 ---------- 

1127 regridded_simulation_name: str 

1128 name given to the output simulation 

1129 target_grid: xr.DataArray or xu.UgridDataArray 

1130 discretization onto which the models in this simulation will be regridded 

1131 validate: bool 

1132 set to true to validate the regridded packages 

1133 

1134 Returns 

1135 ------- 

1136 a new simulation object with regridded models 

1137 """ 

1138 

1139 return _regrid_like(self, regridded_simulation_name, target_grid, validate) 

1140 

1141 def _add_modelsplit_exchanges(self, exchanges_list: list[GWFGWF]) -> None: 

1142 if not self.is_split(): 

1143 self["split_exchanges"] = [] 

1144 self["split_exchanges"].extend(exchanges_list) 

1145 

1146 def _set_flow_exchange_options(self) -> None: 

1147 # collect some options that we will auto-set 

1148 for exchange in self["split_exchanges"]: 

1149 if isinstance(exchange, GWFGWF): 

1150 model_name_1 = exchange.dataset["model_name_1"].values[()] 

1151 model_1 = self[model_name_1] 

1152 exchange.set_options( 

1153 save_flows=model_1["oc"].is_budget_output, 

1154 dewatered=model_1["npf"].is_dewatered, 

1155 variablecv=model_1["npf"].is_variable_vertical_conductance, 

1156 xt3d=model_1["npf"].get_xt3d_option(), 

1157 newton=model_1.is_use_newton(), 

1158 ) 

1159 

1160 def _set_transport_exchange_options(self) -> None: 

1161 for exchange in self["split_exchanges"]: 

1162 if isinstance(exchange, GWTGWT): 

1163 model_name_1 = exchange.dataset["model_name_1"].values[()] 

1164 model_1 = self[model_name_1] 

1165 advection_key = model_1._get_pkgkey("adv") 

1166 dispersion_key = model_1._get_pkgkey("dsp") 

1167 

1168 scheme = None 

1169 xt3d_off = None 

1170 xt3d_rhs = None 

1171 if advection_key is not None: 

1172 scheme = model_1[advection_key].dataset["scheme"].values[()] 

1173 if dispersion_key is not None: 

1174 xt3d_off = model_1[dispersion_key].dataset["xt3d_off"].values[()] 

1175 xt3d_rhs = model_1[dispersion_key].dataset["xt3d_rhs"].values[()] 

1176 exchange.set_options( 

1177 save_flows=model_1["oc"].is_budget_output, 

1178 adv_scheme=scheme, 

1179 dsp_xt3d_off=xt3d_off, 

1180 dsp_xt3d_rhs=xt3d_rhs, 

1181 ) 

1182 

1183 def _filter_inactive_cells_from_exchanges(self) -> None: 

1184 for ex in self["split_exchanges"]: 

1185 for i in [1, 2]: 

1186 self._filter_inactive_cells_exchange_domain(ex, i) 

1187 

1188 def _filter_inactive_cells_exchange_domain(self, ex: GWFGWF, i: int) -> None: 

1189 """Filters inactive cells from one exchange domain inplace""" 

1190 modelname = ex[f"model_name_{i}"].values[()] 

1191 domain = self[modelname].domain 

1192 

1193 layer = ex.dataset["layer"] - 1 

1194 id = ex.dataset[f"cell_id{i}"] - 1 

1195 if is_unstructured(domain): 

1196 exchange_cells = { 

1197 "layer": layer, 

1198 "mesh2d_nFaces": id, 

1199 } 

1200 else: 

1201 exchange_cells = { 

1202 "layer": layer, 

1203 "y": id.sel({f"cell_dims{i}": f"row_{i}"}), 

1204 "x": id.sel({f"cell_dims{i}": f"column_{i}"}), 

1205 } 

1206 exchange_domain = domain.isel(exchange_cells) 

1207 active_exchange_domain = exchange_domain.where(exchange_domain.values > 0) 

1208 active_exchange_domain = active_exchange_domain.dropna("index") 

1209 ex.dataset = ex.dataset.sel(index=active_exchange_domain["index"]) 

1210 

1211 def get_solution_name(self, model_name: str) -> Optional[str]: 

1212 for k, v in self.items(): 

1213 if isinstance(v, Solution): 

1214 if model_name in v.dataset["modelnames"]: 

1215 return k 

1216 return None 

1217 

1218 def __repr__(self) -> str: 

1219 typename = type(self).__name__ 

1220 INDENT = " " 

1221 attrs = [ 

1222 f"{typename}(", 

1223 f"{INDENT}name={repr(self.name)},", 

1224 f"{INDENT}directory={repr(self.directory)}", 

1225 ] 

1226 items = [ 

1227 f"{INDENT}{repr(key)}: {type(value).__name__}," 

1228 for key, value in self.items() 

1229 ] 

1230 # Place the emtpy dict on the same line. Looks silly otherwise. 

1231 if items: 

1232 content = attrs + ["){"] + items + ["}"] 

1233 else: 

1234 content = attrs + ["){}"] 

1235 return "\n".join(content) 

1236 

1237 def _get_transport_models_per_flow_model(self) -> dict[str, list[str]]: 

1238 flow_models = self.get_models_of_type("gwf6") 

1239 transport_models = self.get_models_of_type("gwt6") 

1240 # exchange for flow and transport 

1241 result = collections.defaultdict(list) 

1242 

1243 for flow_model_name in flow_models: 

1244 flow_model = self[flow_model_name] 

1245 for tpt_model_name in transport_models: 

1246 tpt_model = self[tpt_model_name] 

1247 if is_equal(tpt_model.domain, flow_model.domain): 

1248 result[flow_model_name].append(tpt_model_name) 

1249 return result 

1250 

1251 def _generate_gwfgwt_exchanges(self) -> list[GWFGWT]: 

1252 exchanges = [] 

1253 flow_transport_mapping = self._get_transport_models_per_flow_model() 

1254 for flow_name, tpt_models_of_flow_model in flow_transport_mapping.items(): 

1255 if len(tpt_models_of_flow_model) > 0: 

1256 for transport_model_name in tpt_models_of_flow_model: 

1257 exchanges.append(GWFGWT(flow_name, transport_model_name)) 

1258 

1259 return exchanges 

1260 

1261 def _update_ssm_packages(self) -> None: 

1262 flow_transport_mapping = self._get_transport_models_per_flow_model() 

1263 for flow_name, tpt_models_of_flow_model in flow_transport_mapping.items(): 

1264 flow_model = self[flow_name] 

1265 for tpt_model_name in tpt_models_of_flow_model: 

1266 tpt_model = self[tpt_model_name] 

1267 ssm_key = tpt_model._get_pkgkey("ssm") 

1268 if ssm_key is not None: 

1269 old_ssm_package = tpt_model.pop(ssm_key) 

1270 state_variable_name = old_ssm_package.dataset[ 

1271 "auxiliary_variable_name" 

1272 ].values[0] 

1273 ssm_package = SourceSinkMixing.from_flow_model( 

1274 flow_model, state_variable_name, is_split=self.is_split() 

1275 ) 

1276 if ssm_package is not None: 

1277 tpt_model[ssm_key] = ssm_package 

1278 

1279 def _update_buoyancy_packages(self) -> None: 

1280 flow_transport_mapping = self._get_transport_models_per_flow_model() 

1281 for flow_name, tpt_models_of_flow_model in flow_transport_mapping.items(): 

1282 flow_model = self[flow_name] 

1283 flow_model.update_buoyancy_package(tpt_models_of_flow_model) 

1284 

1285 def is_split(self) -> bool: 

1286 return "split_exchanges" in self.keys() 

1287 

1288 def has_one_flow_model(self) -> bool: 

1289 flow_models = self.get_models_of_type("gwf6") 

1290 return len(flow_models) == 1 

1291 

1292 def mask_all_models( 

1293 self, 

1294 mask: GridDataArray, 

1295 ): 

1296 """ 

1297 This function applies a mask to all models in a simulation, provided they use 

1298 the same discretization. The method parameter "mask" is an idomain-like array. 

1299 Masking will overwrite idomain with the mask where the mask is 0 or -1. 

1300 Where the mask is 1, the original value of idomain will be kept. 

1301 Masking will update the packages accordingly, blanking their input where needed, 

1302 and is therefore not a reversible operation. 

1303 

1304 Parameters 

1305 ---------- 

1306 mask: xr.DataArray, xu.UgridDataArray of ints 

1307 idomain-like integer array. 1 sets cells to active, 0 sets cells to inactive, 

1308 -1 sets cells to vertical passthrough 

1309 """ 

1310 _mask_all_models(self, mask)