Skip to content

Commit

Permalink
Refactor MB06 (#51)
Browse files Browse the repository at this point in the history
* refactor to avoid loops

* apply Diego's suggestions

* remove useless return after error
  • Loading branch information
malmans2 authored Jun 30, 2021
1 parent 6750533 commit 8e08398
Showing 1 changed file with 53 additions and 74 deletions.
127 changes: 53 additions & 74 deletions pydomcfg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
Utilities
"""

# from itertools import product
from typing import Hashable, Iterable, Iterator, Optional

import numpy as np
Expand Down Expand Up @@ -80,10 +79,15 @@ def _calc_rmax(depth: DataArray) -> DataArray:
return rmax.fillna(0)


def _smooth_MB06(depth: DataArray, rmax: float) -> DataArray:
def _smooth_MB06(
depth: DataArray,
rmax: float,
tol: float = 1.0e-8,
max_iter: int = 10_000,
) -> DataArray:
"""
This is NEMO implementation of the direct iterative method
of Martinho and Batteen (2006).
Direct iterative method of Martinho and Batteen (2006) consistent
with NEMO implementation.
The algorithm ensures that
Expand All @@ -100,87 +104,62 @@ def _smooth_MB06(depth: DataArray, rmax: float) -> DataArray:
Parameters
----------
depth: DataArray
Bottom depth (units: m).
Bottom depth.
rmax: float
Maximum slope parameter allowed
tol: float, default = 1.0e-8
Tolerance for the iterative method
max_iter: int, default = 10000
Maximum number of iterations
Returns
-------
DataArray
Smooth version of the bottom topography with
a maximum slope parameter < rmax (units: m).
a maximum slope parameter < rmax.
"""

# set scaling factor used for smoothing
# Set scaling factor used for smoothing
zrfact = (1.0 - rmax) / (1.0 + rmax)

# getting the actual numpy array
# TO BE OPTIMISED
da_zenv = depth.copy()
zenv = da_zenv.data
nj = zenv.shape[0]
ni = zenv.shape[1]

# initialise temporary evelope depth arrays
ztmpi1 = zenv.copy()
ztmpi2 = zenv.copy()
ztmpj1 = zenv.copy()
ztmpj2 = zenv.copy()

# Computing the initial maximum slope parameter
zrmax = 1.0 # np.nanmax(_calc_rmax(depth))
zri = np.ones(zenv.shape) # * zrmax
zrj = np.ones(zenv.shape) # * zrmax

tol = 1.0e-8
itr = 0
max_itr = 10000

while itr <= max_itr and (zrmax - rmax) > tol:

itr += 1
zrmax = 0.0
# we set zrmax from previous r-values (zri and zrj) first
# if set after current r-value calculation (as previously)
# we could exit DO WHILE prematurely before checking r-value
# of current zenv
max_zri = np.nanmax(np.absolute(zri))
max_zrj = np.nanmax(np.absolute(zrj))
zrmax = np.nanmax([zrmax, max_zrj, max_zri])

print("Iter:", itr, "rmax: ", zrmax)

zri *= 0.0
zrj *= 0.0

for j in range(nj - 1):
for i in range(ni - 1):
ip1 = np.minimum(i + 1, ni)
jp1 = np.minimum(j + 1, nj)
if zenv[j, i] > 0.0 and zenv[j, ip1] > 0.0:
zri[j, i] = (zenv[j, ip1] - zenv[j, i]) / (
zenv[j, ip1] + zenv[j, i]
)
if zenv[j, i] > 0.0 and zenv[jp1, i] > 0.0:
zrj[j, i] = (zenv[jp1, i] - zenv[j, i]) / (
zenv[jp1, i] + zenv[j, i]
)
if zri[j, i] > rmax:
ztmpi1[j, i] = zenv[j, ip1] * zrfact
if zri[j, i] < -rmax:
ztmpi2[j, ip1] = zenv[j, i] * zrfact
if zrj[j, i] > rmax:
ztmpj1[j, i] = zenv[jp1, i] * zrfact
if zrj[j, i] < -rmax:
ztmpj2[jp1, i] = zenv[j, i] * zrfact

ztmpi = np.maximum(ztmpi1, ztmpi2)
ztmpj = np.maximum(ztmpj1, ztmpj2)
zenv = np.maximum(zenv, np.maximum(ztmpi, ztmpj))

da_zenv.data = zenv
return da_zenv
# Initialize envelope bathymetry
zenv = depth

for _ in range(max_iter):

# Initialize lists of DataArrays to concatenate
all_ztmp = []
all_zr = []
for dim in zenv.dims:

# Shifted arrays
zenv_m1 = zenv.shift({dim: -1})
zenv_p1 = zenv.shift({dim: +1})

# Compute zr
zr = (zenv_m1 - zenv) / (zenv_m1 + zenv)
zr = zr.where((zenv > 0) & (zenv_m1 > 0), 0)
for dim_name in zenv.dims:
zr[{dim_name: -1}] = 0
all_zr += [zr]

# Compute ztmp
zr_p1 = zr.shift({dim: +1})
all_ztmp += [zenv.where(zr <= rmax, zenv_m1 * zrfact)]
all_ztmp += [zenv.where(zr_p1 >= -rmax, zenv_p1 * zrfact)]

# Update envelope bathymetry
zenv = xr.concat([zenv] + all_ztmp, "dummy_dim").max("dummy_dim")

# Check target rmax
zr = xr.concat(all_zr, "dummy_dim")
if ((np.abs(zr) - rmax) <= tol).all():
return zenv

raise ValueError(
"Iterative method did NOT converge."
" You might want to increase the number of iterations and/or the tolerance."
)


def generate_cartesian_grid(
Expand Down

0 comments on commit 8e08398

Please sign in to comment.