Skip to content

Commit

Permalink
Add explicit dtypes to more operations
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Mar 1, 2022
1 parent 2c14ece commit a14864a
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 77 deletions.
5 changes: 3 additions & 2 deletions grudge/discretization.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
THE SOFTWARE.
"""

from typing import Any, Optional
from pytools import memoize_method

from grudge.dof_desc import (
Expand Down Expand Up @@ -698,7 +699,7 @@ def order(self):

# {{{ Discretization-specific geometric properties

def nodes(self, dd=None):
def nodes(self, dd=None, dtype: Optional[np.dtype[Any]] = None):
r"""Return the nodes of a discretization specified by *dd*.
:arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one.
Expand All @@ -707,7 +708,7 @@ def nodes(self, dd=None):
"""
if dd is None:
dd = DD_VOLUME
return self.discr_from_dd(dd).nodes()
return self.discr_from_dd(dd).nodes(dtype)

def normal(self, dd):
r"""Get the unit normal to the specified surface discretization, *dd*.
Expand Down
110 changes: 72 additions & 38 deletions grudge/geometry/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,14 @@
"""


from typing import Optional, Any

import numpy as np

from arraycontext import thaw, freeze, ArrayContext
from meshmode.dof_array import DOFArray

from grudge.tools import to_real_dtype
from grudge import DiscretizationCollection
import grudge.dof_desc as dof_desc

Expand Down Expand Up @@ -111,7 +114,8 @@ def to_quad(vec):
def forward_metric_nth_derivative(
actx: ArrayContext, dcoll: DiscretizationCollection,
xyz_axis, ref_axes, dd=None,
*, _use_geoderiv_connection=False) -> DOFArray:
*, _use_geoderiv_connection=False,
dtype: Optional[np.dtype[Any]] = None) -> DOFArray:
r"""Pointwise metric derivatives representing repeated derivatives of the
physical coordinate enumerated by *xyz_axis*: :math:`x_{\mathrm{xyz\_axis}}`
with respect to the coordiantes on the reference element :math:`\xi_i`:
Expand Down Expand Up @@ -169,7 +173,8 @@ def forward_metric_nth_derivative(
vec = num_reference_derivative(
dcoll.discr_from_dd(inner_dd),
flat_ref_axes,
thaw(dcoll.discr_from_dd(inner_dd).nodes(), actx)[xyz_axis]
thaw(dcoll.discr_from_dd(inner_dd).nodes(
dtype=to_real_dtype(dtype)), actx)[xyz_axis]
)

return _geometry_to_quad_if_requested(
Expand All @@ -178,7 +183,8 @@ def forward_metric_nth_derivative(

def forward_metric_derivative_vector(
actx: ArrayContext, dcoll: DiscretizationCollection, rst_axis, dd=None,
*, _use_geoderiv_connection=False) -> np.ndarray:
*, _use_geoderiv_connection=False, dtype: Optional[np.dtype[Any]] = None
) -> np.ndarray:
r"""Computes an object array containing the forward metric derivatives
of each physical coordinate.
Expand All @@ -195,15 +201,17 @@ def forward_metric_derivative_vector(
return make_obj_array([
forward_metric_nth_derivative(
actx, dcoll, i, rst_axis, dd=dd,
_use_geoderiv_connection=_use_geoderiv_connection)
_use_geoderiv_connection=_use_geoderiv_connection,
dtype=dtype)
for i in range(dcoll.ambient_dim)
]
)


def forward_metric_derivative_mv(
actx: ArrayContext, dcoll: DiscretizationCollection, rst_axis, dd=None,
*, _use_geoderiv_connection=False) -> MultiVector:
*, _use_geoderiv_connection=False, dtype: Optional[np.dtype[Any]] = None
) -> MultiVector:
r"""Computes a :class:`pymbolic.geometric_algebra.MultiVector` containing
the forward metric derivatives of each physical coordinate.
Expand All @@ -220,13 +228,15 @@ def forward_metric_derivative_mv(
return MultiVector(
forward_metric_derivative_vector(
actx, dcoll, rst_axis, dd=dd,
_use_geoderiv_connection=_use_geoderiv_connection)
_use_geoderiv_connection=_use_geoderiv_connection,
dtype=dtype)
)


def forward_metric_derivative_mat(
actx: ArrayContext, dcoll: DiscretizationCollection, dd=None,
*, _use_geoderiv_connection=False) -> np.ndarray:
*, _use_geoderiv_connection=False,
dtype: Optional[np.dtype[Any]] = None) -> np.ndarray:
r"""Computes the forward metric derivative matrix, also commonly
called the Jacobian matrix, with entries defined as the
forward metric derivatives:
Expand Down Expand Up @@ -260,13 +270,15 @@ def forward_metric_derivative_mat(
for j in range(dim):
result[:, j] = forward_metric_derivative_vector(
actx, dcoll, j, dd=dd,
_use_geoderiv_connection=_use_geoderiv_connection)
_use_geoderiv_connection=_use_geoderiv_connection,
dtype=dtype)

return result


def first_fundamental_form(actx: ArrayContext, dcoll: DiscretizationCollection,
dd=None, *, _use_geoderiv_connection=False) -> np.ndarray:
dd=None, *, _use_geoderiv_connection=False,
dtype: Optional[np.dtype[Any]] = None) -> np.ndarray:
r"""Computes the first fundamental form using the Jacobian matrix:
.. math::
Expand Down Expand Up @@ -295,14 +307,16 @@ def first_fundamental_form(actx: ArrayContext, dcoll: DiscretizationCollection,
dd = DD_VOLUME

mder = forward_metric_derivative_mat(
actx, dcoll, dd=dd, _use_geoderiv_connection=_use_geoderiv_connection)
actx, dcoll, dd=dd, _use_geoderiv_connection=_use_geoderiv_connection,
dtype=dtype)

return mder.T.dot(mder)


def inverse_metric_derivative_mat(
actx: ArrayContext, dcoll: DiscretizationCollection, dd=None,
*, _use_geoderiv_connection=False) -> np.ndarray:
*, _use_geoderiv_connection=False, dtype: Optional[np.dtype[Any]] = None
) -> np.ndarray:
r"""Computes the inverse metric derivative matrix, which is
the inverse of the Jacobian (forward metric derivative) matrix.
Expand All @@ -324,15 +338,16 @@ def inverse_metric_derivative_mat(
for j in range(ambient_dim):
result[i, j] = inverse_metric_derivative(
actx, dcoll, i, j, dd=dd,
_use_geoderiv_connection=_use_geoderiv_connection
)
_use_geoderiv_connection=_use_geoderiv_connection,
dtype=dtype)

return result


def inverse_first_fundamental_form(
actx: ArrayContext, dcoll: DiscretizationCollection, dd=None,
*, _use_geoderiv_connection=False) -> np.ndarray:
*, _use_geoderiv_connection=False, dtype: Optional[np.dtype[Any]]
) -> np.ndarray:
r"""Computes the inverse of the first fundamental form:
.. math::
Expand Down Expand Up @@ -361,11 +376,13 @@ def inverse_first_fundamental_form(

if dcoll.ambient_dim == dim:
inv_mder = inverse_metric_derivative_mat(
actx, dcoll, dd=dd, _use_geoderiv_connection=_use_geoderiv_connection)
actx, dcoll, dd=dd, _use_geoderiv_connection=_use_geoderiv_connection,
dtype=dtype)
inv_form1 = inv_mder.dot(inv_mder.T)
else:
form1 = first_fundamental_form(
actx, dcoll, dd=dd, _use_geoderiv_connection=_use_geoderiv_connection)
actx, dcoll, dd=dd, _use_geoderiv_connection=_use_geoderiv_connection,
dtype=dtype)

if dim == 1:
inv_form1 = 1.0 / form1
Expand All @@ -383,7 +400,7 @@ def inverse_first_fundamental_form(

def inverse_metric_derivative(
actx: ArrayContext, dcoll: DiscretizationCollection, rst_axis, xyz_axis, dd,
*, _use_geoderiv_connection=False
*, _use_geoderiv_connection=False, dtype: Optional[np.dtype[Any]] = None
) -> DOFArray:
r"""Computes the inverse metric derivative of the physical
coordinate enumerated by *xyz_axis* with respect to the
Expand All @@ -409,7 +426,8 @@ def inverse_metric_derivative(
par_vecs = [
forward_metric_derivative_mv(
actx, dcoll, rst, dd,
_use_geoderiv_connection=_use_geoderiv_connection)
_use_geoderiv_connection=_use_geoderiv_connection,
dtype=dtype)
for rst in range(dim)]

# Yay Cramer's rule!
Expand Down Expand Up @@ -442,7 +460,8 @@ def outprod_with_unit(i, at):
def inverse_surface_metric_derivative(
actx: ArrayContext, dcoll: DiscretizationCollection,
rst_axis, xyz_axis, dd=None,
*, _use_geoderiv_connection=False):
*, _use_geoderiv_connection=False,
dtype: Optional[np.dtype[Any]]):
r"""Computes the inverse surface metric derivative of the physical
coordinate enumerated by *xyz_axis* with respect to the
reference axis *rst_axis*. These geometric terms are used in the
Expand All @@ -467,24 +486,24 @@ def inverse_surface_metric_derivative(
dd = dof_desc.as_dofdesc(dd)

if ambient_dim == dim:
result = inverse_metric_derivative(
return inverse_metric_derivative(
actx, dcoll, rst_axis, xyz_axis, dd=dd,
_use_geoderiv_connection=_use_geoderiv_connection
)
_use_geoderiv_connection=_use_geoderiv_connection,
dtype=dtype)
else:
inv_form1 = inverse_first_fundamental_form(actx, dcoll, dd=dd)
result = sum(
return sum(
inv_form1[rst_axis, d]*forward_metric_nth_derivative(
actx, dcoll, xyz_axis, d, dd=dd,
_use_geoderiv_connection=_use_geoderiv_connection
_use_geoderiv_connection=_use_geoderiv_connection,
dtype=dtype,
) for d in range(dim))

return result


def inverse_surface_metric_derivative_mat(
actx: ArrayContext, dcoll: DiscretizationCollection, dd=None,
*, times_area_element=False, _use_geoderiv_connection=False):
*, times_area_element=False, _use_geoderiv_connection=False,
dtype: Optional[np.dtype[Any]] = None):
r"""Computes the matrix of inverse surface metric derivatives, indexed by
``(xyz_axis, rst_axis)``. It returns all values of
:func:`inverse_surface_metric_derivative_mat` in cached matrix form.
Expand All @@ -505,7 +524,7 @@ def inverse_surface_metric_derivative_mat(

@memoize_in(dcoll, (inverse_surface_metric_derivative_mat, dd,
times_area_element, _use_geoderiv_connection))
def _inv_surf_metric_deriv():
def _inv_surf_metric_deriv(dtype):
if times_area_element:
multiplier = area_element(actx, dcoll, dd=dd,
_use_geoderiv_connection=_use_geoderiv_connection)
Expand All @@ -517,13 +536,17 @@ def _inv_surf_metric_deriv():
multiplier
* inverse_surface_metric_derivative(actx, dcoll,
rst_axis, xyz_axis, dd=dd,
_use_geoderiv_connection=_use_geoderiv_connection)
_use_geoderiv_connection=_use_geoderiv_connection,
dtype=dtype)
for rst_axis in range(dcoll.dim)])
for xyz_axis in range(dcoll.ambient_dim)])

return freeze(mat, actx)

return thaw(_inv_surf_metric_deriv(), actx)
if dtype is not None:
dtype = to_real_dtype(dtype)

return thaw(_inv_surf_metric_deriv(dtype), actx)


def _signed_face_ones(
Expand Down Expand Up @@ -557,7 +580,8 @@ def _signed_face_ones(

def parametrization_derivative(
actx: ArrayContext, dcoll: DiscretizationCollection, dd,
*, _use_geoderiv_connection=False) -> MultiVector:
*, _use_geoderiv_connection=False,
dtype: Optional[np.dtype[Any]]) -> MultiVector:
r"""Computes the product of forward metric derivatives spanning the
tangent space with topological dimension *dim*.
Expand Down Expand Up @@ -585,13 +609,15 @@ def parametrization_derivative(
return product(
forward_metric_derivative_mv(
actx, dcoll, rst_axis, dd,
_use_geoderiv_connection=_use_geoderiv_connection)
_use_geoderiv_connection=_use_geoderiv_connection,
dtype=dtype)
for rst_axis in range(dim)
)


def pseudoscalar(actx: ArrayContext, dcoll: DiscretizationCollection,
dd=None, *, _use_geoderiv_connection=False) -> MultiVector:
dd=None, *, _use_geoderiv_connection=False,
dtype: Optional[np.dtype[Any]]) -> MultiVector:
r"""Computes the field of pseudoscalars for the domain/discretization
identified by *dd*.
Expand All @@ -607,12 +633,14 @@ def pseudoscalar(actx: ArrayContext, dcoll: DiscretizationCollection,

return parametrization_derivative(
actx, dcoll, dd,
_use_geoderiv_connection=_use_geoderiv_connection).project_max_grade()
_use_geoderiv_connection=_use_geoderiv_connection,
dtype=dtype).project_max_grade()


def area_element(
actx: ArrayContext, dcoll: DiscretizationCollection, dd=None,
*, _use_geoderiv_connection=False
*, _use_geoderiv_connection=False,
dtype: Optional[np.dtype[Any]] = None
) -> DOFArray:
r"""Computes the scale factor used to transform integrals from reference
to global space.
Expand All @@ -623,22 +651,28 @@ def area_element(
Defaults to the base volume discretization.
:arg _use_geoderiv_connection: For internal use. See
:func:`forward_metric_nth_derivative` for an explanation.
:arg dtype: the :class:`numpy.dtype` with which to return the area element
data.
:returns: a :class:`~meshmode.dof_array.DOFArray` containing the transformed
volumes for each element.
"""
if dd is None:
dd = DD_VOLUME

@memoize_in(dcoll, (area_element, dd, _use_geoderiv_connection))
def _area_elements():
def _area_elements(dtype: np.dtype[Any]):
result = actx.np.sqrt(
pseudoscalar(
actx, dcoll, dd=dd,
_use_geoderiv_connection=_use_geoderiv_connection).norm_squared())
_use_geoderiv_connection=_use_geoderiv_connection,
dtype=dtype).norm_squared())

return freeze(result, actx)

return thaw(_area_elements(), actx)
if dtype is not None:
dtype = to_real_dtype(dtype)

return thaw(_area_elements(dtype), actx)

# }}}

Expand Down
Loading

0 comments on commit a14864a

Please sign in to comment.