Skip to content

Commit

Permalink
Refactor drivercoupling (#273)
Browse files Browse the repository at this point in the history
* Refactor into CoupledModels -> DriverCoupling -> Mappings

* More refactoring

* More work

* Remove dataclasses for driver coupling classes, use pydantic basemodel instead

* Cleanup mapping types

* Add arbitrary_types_allowed for pydantic and geodataframes

* Fix ribasim type hints

* Avoid index kwarg type error

* Export DriverCoupling.
Move mappings one level up for ease of import in tests

* Get test_primod running again

* Update cases

* Fixes for pre-processing tests

* DriverCoupling name changes

* Tests running locally

* Create SvatUserDemandMapping class
Try to improve logic a bit, probably WIP...

* Please mypy

* Address review comments

* lower dir names. Fixes #228
  • Loading branch information
Huite committed Mar 21, 2024
1 parent d6ab10b commit ac5eae5
Show file tree
Hide file tree
Showing 36 changed files with 1,722 additions and 1,684 deletions.
16 changes: 15 additions & 1 deletion pre-processing/primod/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,21 @@
from primod.driver_coupling import (
MetaModDriverCoupling,
RibaMetaDriverCoupling,
RibaModActiveDriverCoupling,
RibaModPassiveDriverCoupling,
)
from primod.metamod import MetaMod
from primod.ribametamod import RibaMetaMod
from primod.ribamod import RibaMod

__all__ = ["MetaMod", "RibaMod", "RibaMetaMod"]
__all__ = (
"MetaMod",
"RibaMod",
"RibaMetaMod",
"MetaModDriverCoupling",
"RibaMetaDriverCoupling",
"RibaModActiveDriverCoupling",
"RibaModPassiveDriverCoupling",
)

__version__ = "2024.2.1"
53 changes: 53 additions & 0 deletions pre-processing/primod/coupled_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import abc
from collections.abc import Sequence
from pathlib import Path
from typing import Any

from primod.driver_coupling.driver_coupling_base import DriverCoupling


class CoupledModel(abc.ABC):
coupling_list: Sequence[DriverCoupling]

@abc.abstractmethod
def write(self, directory: str | Path, *args: Any, **kwargs: Any) -> None:
pass

@abc.abstractmethod
def write_toml(self, directory: str | Path, *args: Any, **kwargs: Any) -> None:
pass

@staticmethod
def _merge_coupling_dicts(dicts: list[dict[str, Any]]) -> dict[str, Any]:
coupling_dict: dict[str, dict[str, Any] | Any] = {}
for top_dict in dicts:
for top_key, top_value in top_dict.items():
if isinstance(top_value, dict):
if top_key not in coupling_dict:
coupling_dict[top_key] = {}
for key, filename in top_value.items():
coupling_dict[top_key][key] = filename
else:
coupling_dict[top_key] = top_value
return coupling_dict

def write_exchanges(self, directory: str | Path) -> dict[str, Any]:
"""
Write exchanges and return their filenames for the coupler
configuration file.
"""
directory = Path(directory)
exchange_dir = Path(directory) / "exchanges"
exchange_dir.mkdir(exist_ok=True, parents=True)

coupling_dicts = []
for coupling in self.coupling_list:
coupling_dict = coupling.write_exchanges(
directory=exchange_dir, coupled_model=self
)
coupling_dicts.append(coupling_dict)

# FUTURE: if we support multiple MF6 models, group them by name before
# merging, and return a list of coupling_dicts.
merged_coupling_dict = self._merge_coupling_dicts(coupling_dicts)
return merged_coupling_dict
13 changes: 13 additions & 0 deletions pre-processing/primod/driver_coupling/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from primod.driver_coupling.metamod import MetaModDriverCoupling
from primod.driver_coupling.ribameta import RibaMetaDriverCoupling
from primod.driver_coupling.ribamod import (
RibaModActiveDriverCoupling,
RibaModPassiveDriverCoupling,
)

__all__ = (
"MetaModDriverCoupling",
"RibaMetaDriverCoupling",
"RibaModActiveDriverCoupling",
"RibaModPassiveDriverCoupling",
)
22 changes: 22 additions & 0 deletions pre-processing/primod/driver_coupling/driver_coupling_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import abc
from pathlib import Path
from typing import Any

from pydantic import BaseModel


class DriverCoupling(BaseModel, abc.ABC):
"""
Abstract base class for driver couplings.
"""

# Config required for e.g. geodataframes
model_config = {"arbitrary_types_allowed": True}

@abc.abstractmethod
def derive_mapping(self, *args: Any, **kwargs: Any) -> Any:
pass

@abc.abstractmethod
def write_exchanges(self, directory: Path, coupled_model: Any) -> dict[str, Any]:
pass
110 changes: 110 additions & 0 deletions pre-processing/primod/driver_coupling/metamod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from pathlib import Path
from typing import Any

from imod.mf6 import GroundwaterFlowModel
from imod.msw import GridData, MetaSwapModel, Sprinkling

from primod.driver_coupling.driver_coupling_base import DriverCoupling
from primod.mapping.node_svat_mapping import NodeSvatMapping
from primod.mapping.rch_svat_mapping import RechargeSvatMapping
from primod.mapping.wel_svat_mapping import WellSvatMapping


class MetaModDriverCoupling(DriverCoupling):
"""
Attributes
----------
mf6_model : str
The model of the driver.
mf6_recharge_package: str
Key of Modflow 6 recharge package to which MetaSWAP is coupled.
mf6_wel_package: str or None
Optional key of Modflow 6 well package to which MetaSWAP sprinkling is
coupled.
"""

mf6_model: str
mf6_recharge_package: str
mf6_wel_package: str | None = None

def _check_sprinkling(
self, msw_model: MetaSwapModel, gwf_model: GroundwaterFlowModel
) -> bool:
sprinkling_key = msw_model._get_pkg_key(Sprinkling, optional_package=True)
sprinkling_in_msw = sprinkling_key is not None
sprinkling_in_mf6 = self.mf6_wel_package in gwf_model.keys()

value = False
match (sprinkling_in_msw, sprinkling_in_mf6):
case (True, False):
raise ValueError(
f"No package named {self.mf6_wel_package} found in Modflow 6 model, "
"but Sprinkling package found in MetaSWAP. "
"iMOD Coupler requires a Well Package "
"to couple wells."
)
case (False, True):
raise ValueError(
f"Modflow 6 Well package {self.mf6_wel_package} specified for sprinkling, "
"but no Sprinkling package found in MetaSWAP model."
)
case (True, True):
value = True
case (False, False):
value = False

return value

def derive_mapping(
self, msw_model: MetaSwapModel, gwf_model: GroundwaterFlowModel
) -> tuple[NodeSvatMapping, RechargeSvatMapping, WellSvatMapping | None]:
if self.mf6_recharge_package not in gwf_model.keys():
raise ValueError(
f"No package named {self.mf6_recharge_package} detected in Modflow 6 model. "
"iMOD_coupler requires a Recharge package."
)

grid_data_key = [
pkgname for pkgname, pkg in msw_model.items() if isinstance(pkg, GridData)
][0]

dis = gwf_model[gwf_model._get_pkgkey("dis")]

index, svat = msw_model[grid_data_key].generate_index_array()
grid_mapping = NodeSvatMapping(svat=svat, modflow_dis=dis, index=index)

recharge = gwf_model[self.mf6_recharge_package]

rch_mapping = RechargeSvatMapping(svat, recharge, index=index)

if self._check_sprinkling(msw_model=msw_model, gwf_model=gwf_model):
well = gwf_model[self.mf6_wel_package]
well_mapping = WellSvatMapping(svat, well, index=index)
return grid_mapping, rch_mapping, well_mapping
else:
return grid_mapping, rch_mapping, None

def write_exchanges(self, directory: Path, coupled_model: Any) -> dict[str, Any]:
mf6_simulation = coupled_model.mf6_simulation
gwf_model = mf6_simulation[self.mf6_model]
msw_model = coupled_model.msw_model

grid_mapping, rch_mapping, well_mapping = self.derive_mapping(
msw_model=msw_model,
gwf_model=gwf_model,
)

coupling_dict: dict[str, Any] = {}
coupling_dict["mf6_model"] = self.mf6_model

coupling_dict["mf6_msw_node_map"] = grid_mapping.write(directory)
coupling_dict["mf6_msw_recharge_pkg"] = self.mf6_recharge_package
coupling_dict["mf6_msw_recharge_map"] = rch_mapping.write(directory)
coupling_dict["enable_sprinkling"] = False

if well_mapping is not None:
coupling_dict["enable_sprinkling"] = True
coupling_dict["mf6_msw_well_pkg"] = self.mf6_wel_package
coupling_dict["mf6_msw_sprinkling_map"] = well_mapping.write(directory)

return coupling_dict
148 changes: 148 additions & 0 deletions pre-processing/primod/driver_coupling/ribameta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import copy
from pathlib import Path
from typing import Any

import geopandas as gpd
import imod
import numpy as np
import ribasim
from imod.msw import GridData, MetaSwapModel, Sprinkling

from primod.driver_coupling.driver_coupling_base import DriverCoupling
from primod.driver_coupling.util import (
_nullify_ribasim_exchange_input,
_validate_node_ids,
)
from primod.mapping.svat_basin_mapping import SvatBasinMapping
from primod.mapping.svat_user_demand_mapping import SvatUserDemandMapping


class RibaMetaDriverCoupling(DriverCoupling):
"""A dataclass representing one coupling scenario for the RibaMod driver.
Attributes
----------
basin_definition: gpd.GeoDataFrame
GeoDataFrame of basin polygons
user_demand_definition: gpd.GeoDataFrame
GeoDataFrame of user demand polygons
"""

ribasim_basin_definition: gpd.GeoDataFrame
ribasim_user_demand_definition: gpd.GeoDataFrame | None = None

def _check_sprinkling(self, msw_model: MetaSwapModel) -> bool:
sprinkling_key = msw_model._get_pkg_key(Sprinkling, optional_package=True)
sprinkling_in_msw = sprinkling_key is not None
sprinkling_in_ribasim = self.ribasim_user_demand_definition is not None

if sprinkling_in_ribasim:
if sprinkling_in_msw:
return True
else:
raise ValueError(
"Ribasim UserDemand definition provided, "
"but no Sprinkling package found in MetaSWAP model."
)
else:
return False

def derive_mapping(
self,
ribasim_model: ribasim.Model,
msw_model: MetaSwapModel,
) -> tuple[SvatBasinMapping, SvatUserDemandMapping | None]:
grid_data_key = [
pkgname for pkgname, pkg in msw_model.items() if isinstance(pkg, GridData)
][0]

index, svat = msw_model[grid_data_key].generate_index_array()
basin_ids = _validate_node_ids(
ribasim_model.basin.node.df, self.ribasim_basin_definition
)
gridded_basin = imod.prepare.rasterize(
self.ribasim_basin_definition,
like=svat,
column="node_id",
)
svat_basin_mapping = SvatBasinMapping(
name="msw_ponding",
gridded_basin=gridded_basin,
basin_ids=basin_ids,
svat=svat,
index=index,
)

if self._check_sprinkling(msw_model=msw_model):
user_demand_ids = _validate_node_ids(
ribasim_model.user_demand.node.df, self.ribasim_user_demand_definition
)
gridded_user_demand = imod.prepare.rasterize(
self.ribasim_basin_definition,
like=svat,
column="node_id",
)
# sprinkling surface water for subsection of svats determined in 'sprinkling'
swspr_grid_data = copy.deepcopy(msw_model[grid_data_key])
nsu = swspr_grid_data.dataset["area"].sizes["subunit"]
swsprmax = msw_model["sprinkling"]
swspr_grid_data.dataset["area"].values = np.tile(
swsprmax["max_abstraction_surfacewater_m3_d"].values,
(nsu, 1, 1),
)
index_swspr, svat_swspr = swspr_grid_data.generate_index_array()
svat_user_demand_mapping = SvatUserDemandMapping(
name="msw_sw_sprinkling",
gridded_user_demand=gridded_user_demand,
user_demand_ids=user_demand_ids,
svat=svat_swspr,
index=index_swspr,
)
return svat_basin_mapping, svat_user_demand_mapping
else:
return svat_basin_mapping, None

def write_exchanges(self, directory: Path, coupled_model: Any) -> dict[str, Any]:
ribasim_model = coupled_model.ribasim_model
msw_model = coupled_model.msw_model

svat_basin_mapping, svat_user_demand_mapping = self.derive_mapping(
ribasim_model=ribasim_model,
msw_model=msw_model,
)

coupling_dict: dict[str, Any] = {}
coupling_dict["rib_msw_ponding_map_surface_water"] = svat_basin_mapping.write(
directory=directory
)

# Set Ribasim runoff input to Null for coupled basins
basin_ids = _validate_node_ids(
ribasim_model.basin.node.df, self.ribasim_basin_definition
)
coupled_basin_indices = svat_basin_mapping.dataframe["basin_index"]
coupled_basin_node_ids = basin_ids[coupled_basin_indices]
_nullify_ribasim_exchange_input(
ribasim_component=ribasim_model.basin,
coupled_node_ids=coupled_basin_node_ids,
columns=["runoff"],
)

# Now deal with sprinkling if set
if svat_user_demand_mapping is not None:
user_demand_ids = _validate_node_ids(
ribasim_model.user_demand.node.df, self.ribasim_user_demand_definition
)
coupling_dict["rib_msw_sprinkling_map_surface_water"] = (
svat_user_demand_mapping.write(directory=directory)
)
coupled_user_demand_indices = svat_user_demand_mapping.dataframe[
"user_demand_index"
]
coupled_user_demand_node_ids = user_demand_ids[coupled_user_demand_indices]
_nullify_ribasim_exchange_input(
ribasim_component=ribasim_model.user_demand,
coupled_node_ids=coupled_user_demand_node_ids,
columns=["demand"],
)
return coupling_dict
Loading

0 comments on commit ac5eae5

Please sign in to comment.