Skip to content

Commit

Permalink
remove deprecated uses of flattening
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Oct 21, 2021
1 parent c94f2fb commit d2c7e28
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 28 deletions.
8 changes: 4 additions & 4 deletions meshmode/discretization/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@

from pytools import memoize_method, Record
from pytools.obj_array import make_obj_array
from arraycontext import thaw
from meshmode.dof_array import DOFArray, flatten
from arraycontext import thaw, flatten
from meshmode.dof_array import DOFArray

from modepy.shapes import Shape, Simplex, Hypercube

Expand Down Expand Up @@ -139,7 +139,7 @@ def _resample_to_numpy(conn, vis_discr, vec, *, stack=False, by_group=False):
from meshmode.dof_array import check_dofarray_against_discr
check_dofarray_against_discr(vis_discr, vec)

return actx.to_numpy(flatten(vec))
return actx.to_numpy(flatten(vec, actx))
else:
raise TypeError(f"unsupported array type: {type(vec).__name__}")

Expand Down Expand Up @@ -523,7 +523,7 @@ def copy_with_same_connectivity(self, actx, discr, skip_tests=False):
def _vis_nodes_numpy(self):
actx = self.vis_discr._setup_actx
return np.array([
actx.to_numpy(flatten(thaw(ary, actx)))
actx.to_numpy(flatten(ary, actx))
for ary in self.vis_discr.nodes()
])

Expand Down
14 changes: 7 additions & 7 deletions test/test_chained.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])

from arraycontext import thaw
from meshmode.dof_array import flatten_to_numpy, flat_norm
from arraycontext import thaw, flatten
from meshmode.dof_array import flat_norm

import logging
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -244,9 +244,9 @@ def f(x):

x = thaw(connections[0].from_discr.nodes(), actx)
fx = f(x)
f1 = resample_mat @ flatten_to_numpy(actx, fx)
f2 = flatten_to_numpy(actx, chained(fx))
f3 = flatten_to_numpy(actx, connections[1](connections[0](fx)))
f1 = resample_mat @ actx.to_numpy(flatten(fx, actx))
f2 = actx.to_numpy(flatten(chained(fx), actx))
f3 = actx.to_numpy(flatten(connections[1](connections[0](fx)), actx))

assert np.allclose(f1, f2)
assert np.allclose(f2, f3)
Expand Down Expand Up @@ -311,13 +311,13 @@ def f(x):
fx = f(x)

t_start = time.time()
f1 = flatten_to_numpy(actx, direct(fx))
f1 = actx.to_numpy(flatten(direct(fx), actx))
t_end = time.time()
if visualize:
print("[TIME] Direct: {:.5e}".format(t_end - t_start))

t_start = time.time()
f2 = flatten_to_numpy(actx, chained(fx))
f2 = actx.to_numpy(flatten(chained(fx), actx))
t_end = time.time()
if visualize:
print("[TIME] Chained: {:.5e}".format(t_end - t_start))
Expand Down
18 changes: 10 additions & 8 deletions test/test_meshmode.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import numpy.linalg as la

import meshmode # noqa: F401
from arraycontext import thaw
from arraycontext import thaw, flatten

from meshmode.array_context import PytestPyOpenCLArrayContextFactory
from arraycontext import pytest_generate_tests_for_array_contexts
Expand All @@ -44,7 +44,7 @@
LegendreGaussLobattoTensorProductGroupFactory
)
from meshmode.mesh import Mesh, BTAG_ALL
from meshmode.dof_array import flatten_to_numpy, flat_norm
from meshmode.dof_array import flat_norm
from meshmode.discretization.connection import \
FACE_RESTR_ALL, FACE_RESTR_INTERIOR
import meshmode.mesh.generation as mgen
Expand Down Expand Up @@ -181,9 +181,10 @@ def f(x):
make_direct_full_resample_matrix
mat = actx.to_numpy(
make_direct_full_resample_matrix(actx, bdry_connection))
bdry_f_2_by_mat = mat.dot(flatten_to_numpy(actx, vol_f))
bdry_f_2_by_mat = mat.dot(actx.to_numpy(flatten(vol_f, actx)))

mat_error = la.norm(flatten_to_numpy(actx, bdry_f_2) - bdry_f_2_by_mat)
mat_error = la.norm(
actx.to_numpy(flatten(bdry_f_2, actx)) - bdry_f_2_by_mat)
assert mat_error < 1e-14, mat_error

err = flat_norm(bdry_f-bdry_f_2, np.inf)
Expand Down Expand Up @@ -489,8 +490,9 @@ def test_orientation_3d(actx_factory, what, mesh_gen_func, visualize=False):
normal_outward_expr = (
sym.normal(mesh.ambient_dim) | sym.nodes(mesh.ambient_dim))

normal_outward_check = flatten_to_numpy(actx,
bind(discr, normal_outward_expr)(actx).as_scalar()) > 0
normal_outward_check = actx.to_numpy(flatten(
bind(discr, normal_outward_expr)(actx).as_scalar(),
actx)) > 0

assert normal_outward_check.all(), normal_outward_check

Expand Down Expand Up @@ -593,7 +595,7 @@ def test_sanity_single_element(actx_factory, dim, mesh_order, group_cls,
| (sym.nodes(dim) + 0.5*sym.ones_vec(dim)),
)(actx).as_scalar()

normal_outward_check = flatten_to_numpy(actx, normal_outward_check > 0)
normal_outward_check = actx.to_numpy(flatten(normal_outward_check > 0, actx))
assert normal_outward_check.all(), normal_outward_check

# }}}
Expand Down Expand Up @@ -748,7 +750,7 @@ def test_sanity_balls(actx_factory, src_file, dim, mesh_order, visualize=False):
sym.normal(mesh.ambient_dim) | sym.nodes(mesh.ambient_dim),
)(actx).as_scalar()

normal_outward_check = flatten_to_numpy(actx, normal_outward_check > 0)
normal_outward_check = actx.to_numpy(flatten(normal_outward_check > 0, actx))
assert normal_outward_check.all(), normal_outward_check

# }}}
Expand Down
22 changes: 13 additions & 9 deletions test/test_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
import numpy as np
import pyopencl as cl

from meshmode.dof_array import flatten, unflatten, flat_norm
from meshmode.dof_array import flat_norm

from arraycontext import thaw
from arraycontext import thaw, flatten, unflatten
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
from arraycontext import pytest_generate_tests_for_array_contexts
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
Expand Down Expand Up @@ -488,9 +488,10 @@ def f(x):
remote_f = conn(true_local_f)

# 2.
send_reqs.append(mpi_comm.isend(actx.to_numpy(flatten(remote_f)),
dest=i_remote_part,
tag=TAG_SEND_REMOTE_NODES))
send_reqs.append(mpi_comm.isend(
actx.to_numpy(flatten(remote_f, actx)),
dest=i_remote_part,
tag=TAG_SEND_REMOTE_NODES))

# 3.
buffers = {}
Expand Down Expand Up @@ -518,9 +519,12 @@ def f(x):
send_reqs = []
for i_remote_part in connected_parts:
conn = remote_to_local_bdry_conns[i_remote_part]
local_f = unflatten(actx, conn.from_discr,
actx.from_numpy(remote_to_local_f_data[i_remote_part]))
remote_f = actx.to_numpy(flatten(conn(local_f)))

local_f = unflatten(
thaw(conn.from_discr.nodes()[0], actx),
actx.from_numpy(remote_to_local_f_data[i_remote_part]),
actx)
remote_f = actx.to_numpy(flatten(conn(local_f), actx))

# 5.
send_reqs.append(mpi_comm.isend(remote_f,
Expand Down Expand Up @@ -554,7 +558,7 @@ def f(x):
bdry_discr = local_bdry_conns[i_remote_part].to_discr
bdry_x = thaw(bdry_discr.nodes()[0], actx)

true_local_f = actx.to_numpy(flatten(f(bdry_x)))
true_local_f = actx.to_numpy(flatten(f(bdry_x), actx))
local_f = local_f_data[i_remote_part]

from numpy.linalg import norm
Expand Down

0 comments on commit d2c7e28

Please sign in to comment.