-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor MB06 #51
Refactor MB06 #51
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,8 @@ | |
Utilities | ||
""" | ||
|
||
import warnings | ||
|
||
# from itertools import product | ||
from typing import Hashable, Iterable, Iterator, Optional | ||
|
||
|
@@ -80,7 +82,12 @@ def _calc_rmax(depth: DataArray) -> DataArray: | |
return rmax.fillna(0) | ||
|
||
|
||
def _smooth_MB06(depth: DataArray, rmax: float) -> DataArray: | ||
def _smooth_MB06( | ||
depth: DataArray, | ||
rmax: float, | ||
tol: float = 1.0e-8, | ||
max_iter: int = 10_000, | ||
) -> DataArray: | ||
""" | ||
This is NEMO implementation of the direct iterative method | ||
of Martinho and Batteen (2006). | ||
|
@@ -103,84 +110,62 @@ def _smooth_MB06(depth: DataArray, rmax: float) -> DataArray: | |
Bottom depth (units: m). | ||
rmax: float | ||
Maximum slope parameter allowed | ||
tol: float, default = 1.0e-8 | ||
Tolerance for the iterative method | ||
max_iter: int, default = 10000 | ||
Maximum number of iterations | ||
|
||
Returns | ||
------- | ||
DataArray | ||
Smooth version of the bottom topography with | ||
a maximum slope parameter < rmax (units: m). | ||
|
||
""" | ||
|
||
# set scaling factor used for smoothing | ||
# Set scaling factor used for smoothing | ||
zrfact = (1.0 - rmax) / (1.0 + rmax) | ||
|
||
# getting the actual numpy array | ||
# TO BE OPTIMISED | ||
da_zenv = depth.copy() | ||
zenv = da_zenv.data | ||
nj = zenv.shape[0] | ||
ni = zenv.shape[1] | ||
|
||
# initialise temporary evelope depth arrays | ||
ztmpi1 = zenv.copy() | ||
ztmpi2 = zenv.copy() | ||
ztmpj1 = zenv.copy() | ||
ztmpj2 = zenv.copy() | ||
|
||
# Computing the initial maximum slope parameter | ||
zrmax = 1.0 # np.nanmax(_calc_rmax(depth)) | ||
zri = np.ones(zenv.shape) # * zrmax | ||
zrj = np.ones(zenv.shape) # * zrmax | ||
|
||
tol = 1.0e-8 | ||
itr = 0 | ||
max_itr = 10000 | ||
|
||
while itr <= max_itr and (zrmax - rmax) > tol: | ||
|
||
itr += 1 | ||
zrmax = 0.0 | ||
# we set zrmax from previous r-values (zri and zrj) first | ||
# if set after current r-value calculation (as previously) | ||
# we could exit DO WHILE prematurely before checking r-value | ||
# of current zenv | ||
max_zri = np.nanmax(np.absolute(zri)) | ||
max_zrj = np.nanmax(np.absolute(zrj)) | ||
zrmax = np.nanmax([zrmax, max_zrj, max_zri]) | ||
|
||
print("Iter:", itr, "rmax: ", zrmax) | ||
|
||
zri *= 0.0 | ||
zrj *= 0.0 | ||
|
||
for j in range(nj - 1): | ||
for i in range(ni - 1): | ||
ip1 = np.minimum(i + 1, ni) | ||
jp1 = np.minimum(j + 1, nj) | ||
if zenv[j, i] > 0.0 and zenv[j, ip1] > 0.0: | ||
zri[j, i] = (zenv[j, ip1] - zenv[j, i]) / ( | ||
zenv[j, ip1] + zenv[j, i] | ||
) | ||
if zenv[j, i] > 0.0 and zenv[jp1, i] > 0.0: | ||
zrj[j, i] = (zenv[jp1, i] - zenv[j, i]) / ( | ||
zenv[jp1, i] + zenv[j, i] | ||
) | ||
if zri[j, i] > rmax: | ||
ztmpi1[j, i] = zenv[j, ip1] * zrfact | ||
if zri[j, i] < -rmax: | ||
ztmpi2[j, ip1] = zenv[j, i] * zrfact | ||
if zrj[j, i] > rmax: | ||
ztmpj1[j, i] = zenv[jp1, i] * zrfact | ||
if zrj[j, i] < -rmax: | ||
ztmpj2[jp1, i] = zenv[j, i] * zrfact | ||
|
||
ztmpi = np.maximum(ztmpi1, ztmpi2) | ||
ztmpj = np.maximum(ztmpj1, ztmpj2) | ||
zenv = np.maximum(zenv, np.maximum(ztmpi, ztmpj)) | ||
|
||
da_zenv.data = zenv | ||
return da_zenv | ||
# Initialize envelope bathymetry | ||
zenv = depth | ||
|
||
for _ in range(max_iter): | ||
|
||
# Initialize lists of DataArrays to concatenate | ||
all_ztmp = [] | ||
all_zr = [] | ||
for dim in zenv.dims: | ||
|
||
# Shifted arrays | ||
zenv_m1 = zenv.shift({dim: -1}) | ||
zenv_p1 = zenv.shift({dim: +1}) | ||
|
||
# Compute zr | ||
mb06 = (zenv_m1 - zenv) / (zenv_m1 + zenv) | ||
malmans2 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
zr = mb06.where((zenv > 0) & (zenv_m1 > 0), 0) | ||
for dim_name in zenv.dims: | ||
zr[{dim_name: -1}] = 0 | ||
all_zr += [zr] | ||
|
||
# Compute ztmp | ||
zr_p1 = zr.shift({dim: +1}) | ||
all_ztmp += [zenv.where(zr <= rmax, zenv_m1 * zrfact)] | ||
all_ztmp += [zenv.where(zr_p1 >= -rmax, zenv_p1 * zrfact)] | ||
|
||
# Update envelope bathymetry | ||
zenv = xr.concat([zenv] + all_ztmp, "dummy_dim").max("dummy_dim") | ||
|
||
# Check target rmax | ||
zr = xr.concat(all_zr, "dummy_dim") | ||
if ((np.abs(zr) - rmax) <= tol).all(): | ||
return zenv | ||
|
||
# TODO: | ||
# Warning or error? | ||
warnings.warn( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I personally would put an error: from personal experience, usually it is sufficient tiincrease the tolerance of a couple of orders, an many tines it is relating with the precision of your data ... in this way we force the user to be aware and maybe make the input more consistent precision wise There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, changed to error! Maybe we should decrease the default number of iterations... I'd get pretty mad if I wait hours to get an error and no return :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I agree! |
||
"Iterative method did NOT converge." | ||
" You might want to increase the number of iterations and/or the tolerance." | ||
) | ||
return zenv | ||
|
||
|
||
def generate_cartesian_grid( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would add something like: This is the xarray version of NEMO implementation ....
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reworded, although I didn't explicitly mention xarray as we use it pretty much everywhere and it should be clear from the docs. Also I removed the units, any units should work fine.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I agree with your point ... it was just to make clear to the users (an us :) ) that this is only a more efficient version of the NEMO MB06 implementation