diff --git a/pydomcfg/utils.py b/pydomcfg/utils.py index b670690..27f13b8 100644 --- a/pydomcfg/utils.py +++ b/pydomcfg/utils.py @@ -2,7 +2,6 @@ Utilities """ -# from itertools import product from typing import Hashable, Iterable, Iterator, Optional import numpy as np @@ -80,10 +79,15 @@ def _calc_rmax(depth: DataArray) -> DataArray: return rmax.fillna(0) -def _smooth_MB06(depth: DataArray, rmax: float) -> DataArray: +def _smooth_MB06( + depth: DataArray, + rmax: float, + tol: float = 1.0e-8, + max_iter: int = 10_000, +) -> DataArray: """ - This is NEMO implementation of the direct iterative method - of Martinho and Batteen (2006). + Direct iterative method of Martinho and Batteen (2006) consistent + with NEMO implementation. The algorithm ensures that @@ -100,87 +104,62 @@ def _smooth_MB06(depth: DataArray, rmax: float) -> DataArray: Parameters ---------- depth: DataArray - Bottom depth (units: m). + Bottom depth. rmax: float Maximum slope parameter allowed + tol: float, default = 1.0e-8 + Tolerance for the iterative method + max_iter: int, default = 10000 + Maximum number of iterations Returns ------- DataArray Smooth version of the bottom topography with - a maximum slope parameter < rmax (units: m). - + a maximum slope parameter < rmax. """ - # set scaling factor used for smoothing + # Set scaling factor used for smoothing zrfact = (1.0 - rmax) / (1.0 + rmax) - # getting the actual numpy array - # TO BE OPTIMISED - da_zenv = depth.copy() - zenv = da_zenv.data - nj = zenv.shape[0] - ni = zenv.shape[1] - - # initialise temporary evelope depth arrays - ztmpi1 = zenv.copy() - ztmpi2 = zenv.copy() - ztmpj1 = zenv.copy() - ztmpj2 = zenv.copy() - - # Computing the initial maximum slope parameter - zrmax = 1.0 # np.nanmax(_calc_rmax(depth)) - zri = np.ones(zenv.shape) # * zrmax - zrj = np.ones(zenv.shape) # * zrmax - - tol = 1.0e-8 - itr = 0 - max_itr = 10000 - - while itr <= max_itr and (zrmax - rmax) > tol: - - itr += 1 - zrmax = 0.0 - # we set zrmax from previous r-values (zri and zrj) first - # if set after current r-value calculation (as previously) - # we could exit DO WHILE prematurely before checking r-value - # of current zenv - max_zri = np.nanmax(np.absolute(zri)) - max_zrj = np.nanmax(np.absolute(zrj)) - zrmax = np.nanmax([zrmax, max_zrj, max_zri]) - - print("Iter:", itr, "rmax: ", zrmax) - - zri *= 0.0 - zrj *= 0.0 - - for j in range(nj - 1): - for i in range(ni - 1): - ip1 = np.minimum(i + 1, ni) - jp1 = np.minimum(j + 1, nj) - if zenv[j, i] > 0.0 and zenv[j, ip1] > 0.0: - zri[j, i] = (zenv[j, ip1] - zenv[j, i]) / ( - zenv[j, ip1] + zenv[j, i] - ) - if zenv[j, i] > 0.0 and zenv[jp1, i] > 0.0: - zrj[j, i] = (zenv[jp1, i] - zenv[j, i]) / ( - zenv[jp1, i] + zenv[j, i] - ) - if zri[j, i] > rmax: - ztmpi1[j, i] = zenv[j, ip1] * zrfact - if zri[j, i] < -rmax: - ztmpi2[j, ip1] = zenv[j, i] * zrfact - if zrj[j, i] > rmax: - ztmpj1[j, i] = zenv[jp1, i] * zrfact - if zrj[j, i] < -rmax: - ztmpj2[jp1, i] = zenv[j, i] * zrfact - - ztmpi = np.maximum(ztmpi1, ztmpi2) - ztmpj = np.maximum(ztmpj1, ztmpj2) - zenv = np.maximum(zenv, np.maximum(ztmpi, ztmpj)) - - da_zenv.data = zenv - return da_zenv + # Initialize envelope bathymetry + zenv = depth + + for _ in range(max_iter): + + # Initialize lists of DataArrays to concatenate + all_ztmp = [] + all_zr = [] + for dim in zenv.dims: + + # Shifted arrays + zenv_m1 = zenv.shift({dim: -1}) + zenv_p1 = zenv.shift({dim: +1}) + + # Compute zr + zr = (zenv_m1 - zenv) / (zenv_m1 + zenv) + zr = zr.where((zenv > 0) & (zenv_m1 > 0), 0) + for dim_name in zenv.dims: + zr[{dim_name: -1}] = 0 + all_zr += [zr] + + # Compute ztmp + zr_p1 = zr.shift({dim: +1}) + all_ztmp += [zenv.where(zr <= rmax, zenv_m1 * zrfact)] + all_ztmp += [zenv.where(zr_p1 >= -rmax, zenv_p1 * zrfact)] + + # Update envelope bathymetry + zenv = xr.concat([zenv] + all_ztmp, "dummy_dim").max("dummy_dim") + + # Check target rmax + zr = xr.concat(all_zr, "dummy_dim") + if ((np.abs(zr) - rmax) <= tol).all(): + return zenv + + raise ValueError( + "Iterative method did NOT converge." + " You might want to increase the number of iterations and/or the tolerance." + ) def generate_cartesian_grid(