Skip to content

Commit

Permalink
Call update_bottom_minimum and refactor a little (#251)
Browse files Browse the repository at this point in the history
* Functional change: make sure to call update_bottom_minimum!

Refactoring: define water_level property on Drainage and River.
Define exchange_rib2mod and exchange_mod2rib.

* Disable mypy for this specific ChainMap...

* Remove property footgun

Using .water_level = ... calls the setter
But .water_level[:] does not!
  • Loading branch information
Huite authored Feb 9, 2024
1 parent dc6b1c9 commit bf88ab4
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 43 deletions.
83 changes: 48 additions & 35 deletions imod_coupler/drivers/ribamod/ribamod.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""
from __future__ import annotations

import typing
from collections import ChainMap
from typing import Any

Expand All @@ -20,6 +21,10 @@
from imod_coupler.kernelwrappers.mf6_wrapper import Mf6Drainage, Mf6River, Mf6Wrapper
from imod_coupler.logging.exchange_collector import ExchangeCollector

# iMOD Python sets MODFLOW 6's time unit to days
# Ribasim's time unit is always seconds
RIBAMOD_TIME_FACTOR = 86400


class RibaMod(Driver):
"""The driver coupling Ribasim and MODFLOW 6"""
Expand Down Expand Up @@ -50,6 +55,7 @@ class RibaMod(Driver):
# ChainMaps
mf6_river_packages: ChainMap[str, Mf6River]
mf6_drainage_packages: ChainMap[str, Mf6Drainage]
mf6_active_packages = ChainMap[str, Mf6River | Mf6Drainage]

# Ribasim variables
ribasim_level: NDArray[Any]
Expand Down Expand Up @@ -128,6 +134,9 @@ def couple(self) -> None:
self.mf6_drainage_packages = ChainMap(
self.mf6_active_drainage_packages, self.mf6_passive_drainage_packages
)
self.mf6_active_packages = ChainMap(
self.mf6_active_river_packages, self.mf6_active_drainage_packages
) # type: ignore

# Get the level, drainage, infiltration from Ribasim
self.ribasim_infiltration = self.ribasim.get_value_ptr("infiltration")
Expand Down Expand Up @@ -194,63 +203,67 @@ def couple(self) -> None:

return

def update(self) -> None:
# iMOD Python sets MODFLOW 6' time unit to days
# Ribasim's time unit is always seconds
ribamod_time_factor = 86400

self.ribasim.update_subgrid_level()
# Ensure MODFLOW has river bottoms.
# Variables are otherwise initialized with zeros.
self.mf6.prepare_time_step(0.0)
# Set the MODFLOW 6 river stage and drainage to value of waterlevel of Ribasim basin
for key, river in self.mf6_active_river_packages.items():
new_stage = self.mask_rib2mod[key][:] * river.stage[:] + self.map_rib2mod[
key
].dot(self.subgrid_level)
river.set_stage(new_stage=new_stage)
for key, drainage in self.mf6_active_drainage_packages.items():
new_elevation = self.mask_rib2mod[key][:] * drainage.elevation[
:
] + self.map_rib2mod[key].dot(self.subgrid_level)
drainage.set_elevation(new_elevation=new_elevation)

# One time step in MODFLOW 6
# convergence loop
self.mf6.prepare_solve(1)
for kiter in range(1, self.max_iter + 1):
has_converged = self.do_iter(1)
if has_converged:
logger.debug(f"MF6-Ribasim converged in {kiter} iterations")
break
self.mf6.finalize_solve(1)
self.mf6.finalize_time_step()
@typing.no_type_check
def exchange_rib2mod(self) -> None:
# Mypy refuses to understand this ChainMap for some reason.
# ChainMaps work fine in other places...
for key, package in self.mf6_active_packages.items():
package.update_bottom_minimum()
package.set_water_level(
self.mask_rib2mod[key] * package.water_level
+ self.map_rib2mod[key].dot(self.subgrid_level)
)
return

def exchange_mod2rib(self) -> None:
# Zero the accumulator arrays
self.work_infiltration[:] = 0.0
self.work_drainage[:] = 0.0

# Compute MODFLOW 6 river and drain flux
for key, river in self.mf6_river_packages.items():
river_flux = river.get_flux(self.mf6_head)
ribasim_flux = self.map_mod2rib[key].dot(river_flux) / ribamod_time_factor
ribasim_flux = self.map_mod2rib[key].dot(river_flux) / RIBAMOD_TIME_FACTOR
self.work_infiltration += np.where(ribasim_flux > 0, ribasim_flux, 0)
self.work_drainage += np.where(ribasim_flux < 0, -ribasim_flux, 0)

for key, drainage in self.mf6_drainage_packages.items():
drain_flux = drainage.get_flux(self.mf6_head)
ribasim_flux = self.map_mod2rib[key].dot(drain_flux) / ribamod_time_factor
ribasim_flux = self.map_mod2rib[key].dot(drain_flux) / RIBAMOD_TIME_FACTOR
self.work_drainage -= ribasim_flux

# Set the infiltration and drainage to the actively coupled basins.
self.ribasim_drainage[self.coupled_mod2rib] = self.work_drainage[
self.coupled_mod2rib
]
self.ribasim_infiltration[self.coupled_mod2rib] = self.work_infiltration[
self.coupled_mod2rib
]
return

def update(self) -> None:
self.ribasim.update_subgrid_level()
# Ensure MODFLOW has river bottoms.
# Variables are otherwise initialized with zeros.
self.mf6.prepare_time_step(0.0)
# Set the MODFLOW 6 river stage and drainage to value of waterlevel of Ribasim basin
self.exchange_rib2mod()

# One time step in MODFLOW 6
# convergence loop
self.mf6.prepare_solve(1)
for kiter in range(1, self.max_iter + 1):
has_converged = self.do_iter(1)
if has_converged:
logger.debug(f"MF6-Ribasim converged in {kiter} iterations")
break
self.mf6.finalize_solve(1)
self.mf6.finalize_time_step()

# Set the infiltration and drainage to the coupled basins.
self.exchange_mod2rib()

# Update Ribasim until current time of MODFLOW 6
self.ribasim.update_until(self.mf6.get_current_time() * ribamod_time_factor)
self.ribasim.update_until(self.mf6.get_current_time() * RIBAMOD_TIME_FACTOR)

def do_iter(self, sol_id: int) -> bool:
"""Execute a single iteration"""
Expand Down
31 changes: 25 additions & 6 deletions imod_coupler/kernelwrappers/mf6_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from abc import ABC
from abc import ABC, abstractmethod, abstractproperty
from collections.abc import Sequence
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -507,6 +507,17 @@ def get_flux(
self.q -= self.rhs
return self.q

@abstractproperty
def water_level(self) -> NDArray[np.float64]:
pass

@abstractmethod
def set_water_level(self, new_water_level: NDArray[np.float64]) -> None:
# Do not use @water_level.setter!
# Since instance.water_level[:] = ...
# Will NOT call the setter, only the accessor!
pass


class Mf6River(Mf6HeadBoundary):
nodelist: NDArray[np.int32]
Expand Down Expand Up @@ -540,8 +551,12 @@ def __init__(
def update_bottom_minimum(self) -> None:
self.bottom_minimum[:] = self.bottom_elevation[:]

def set_stage(self, new_stage: NDArray[np.float64]) -> None:
np.maximum(self.bottom_minimum, new_stage, out=self.stage)
@property
def water_level(self) -> NDArray[np.float64]:
return self.stage

def set_water_level(self, new_water_level: NDArray[np.float64]) -> None:
np.maximum(self.bottom_minimum, new_water_level, out=self.stage)


class Mf6Drainage(Mf6HeadBoundary):
Expand All @@ -568,7 +583,11 @@ def __init__(
self.elevation_minimum = self.elevation.copy()

def update_bottom_minimum(self) -> None:
self.elevation_minimum[:] = self.elevation
self.elevation_minimum[:] = self.elevation[:]

@property
def water_level(self) -> NDArray[np.float64]:
return self.elevation

def set_elevation(self, new_elevation: NDArray[np.float64]) -> None:
np.maximum(self.elevation_minimum, new_elevation, out=self.elevation)
def set_water_level(self, new_water_level: NDArray[np.float64]) -> None:
np.maximum(self.elevation_minimum, new_water_level, out=self.elevation)
4 changes: 2 additions & 2 deletions tests/test_mf6_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_mf6_river(
assert (mf6_river.nodelist != -1).any()

# This guards against setting below elevation.
mf6_river.set_stage(-123.0)
mf6_river.set_water_level(np.full_like(mf6_river.water_level, -123.0))
stage_address = mf6wrapper.get_var_address("STAGE", "GWF_1", "Oosterschelde")
stage = mf6wrapper.get_value_ptr(stage_address)
assert (stage > -10.0).all()
Expand Down Expand Up @@ -99,7 +99,7 @@ def test_mf6_drainage(
assert (mf6_drainage.nodelist != -1).any()

# This guards against setting below elevation.
mf6_drainage.set_elevation(-123.0)
mf6_drainage.set_water_level(np.full_like(mf6_drainage.water_level, -123.0))
elev_address = mf6wrapper.get_var_address("ELEV", "GWF_1", "Drainage")
elev = mf6wrapper.get_value_ptr(elev_address)
assert (elev > -10.0).all()
Expand Down

0 comments on commit bf88ab4

Please sign in to comment.