Skip to content

Commit

Permalink
fix porting errors
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Sep 23, 2022
1 parent 79690a1 commit 6737ff8
Show file tree
Hide file tree
Showing 16 changed files with 235 additions and 209 deletions.
23 changes: 19 additions & 4 deletions pytential/array_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
THE SOFTWARE.
"""

from meshmode.array_context import (
from meshmode.array_context import ( # noqa: F401
PyOpenCLArrayContext as MeshmodePyOpenCLArrayContext)
from sumpy.array_context import ( # noqa: F401
from sumpy.array_context import ( # noqa: F401
PyOpenCLArrayContext as SumpyPyOpenCLArrayContext,
make_loopy_program)
from boxtree.array_context import dataclass_array_container # noqa: F401
from boxtree.array_context import dataclass_array_container # noqa: F401
from arraycontext.pytest import (
_PytestPyOpenCLArrayContextFactoryWithClass,
register_pytest_array_context_factory)
Expand All @@ -39,8 +39,23 @@

class PyOpenCLArrayContext(SumpyPyOpenCLArrayContext):
def transform_loopy_program(self, t_unit):
kernel = t_unit.default_entrypoint
options = kernel.options

if not options.return_dict or not options.no_numpy:
raise ValueError(
"loopy kernels passed to 'call_loopy' must have 'return_dict' "
"and 'no_numpy' options set. Did you use 'make_loopy_program' "
f"to create the kernel '{kernel.name}'?")

# FIXME: this probably needs some proper logic
return MeshmodePyOpenCLArrayContext.transform_loopy_program(self, t_unit)
from meshmode.array_context import _transform_loopy_inner
transformed_t_unit = _transform_loopy_inner(t_unit)

if transformed_t_unit is not None:
return transformed_t_unit

return t_unit

# }}}

Expand Down
12 changes: 6 additions & 6 deletions pytential/linalg/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,15 @@ def partition_by_nodes(
from pytential.qbx.utils import tree_code_container
tcc = tree_code_container(lpot_source._setup_actx)

tree, _ = tcc.build_tree()(actx.queue,
tree, _ = tcc.build_tree()(actx,
particles=flatten(
actx.thaw(discr.nodes()), actx, leaf_class=DOFArray
),
max_particles_in_box=max_particles_in_box,
kind=tree_kind)

from boxtree import box_flags_enum
tree = tree.get(actx.queue)
tree = actx.to_numpy(tree)
leaf_boxes, = (tree.box_flags & box_flags_enum.HAS_CHILDREN == 0).nonzero()

indices = np.empty(len(leaf_boxes), dtype=object)
Expand Down Expand Up @@ -687,12 +687,12 @@ def prg():
from pytential.qbx.utils import tree_code_container
tcc = tree_code_container(lpot_source._setup_actx)

tree, _ = tcc.build_tree()(actx.queue, sources,
tree, _ = tcc.build_tree()(actx, sources,
max_particles_in_box=max_particles_in_box)
query, _ = tcc.build_area_query()(actx.queue, tree, pxy.centers, pxy.radii)
query, _ = tcc.build_area_query()(actx, tree, pxy.centers, pxy.radii)

tree = tree.get(actx.queue)
query = query.get(actx.queue)
tree = actx.to_numpy(tree)
query = actx.to_numpy(query)

# }}}

Expand Down
31 changes: 15 additions & 16 deletions pytential/qbx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,14 +436,15 @@ def exec_compute_potential_insn(self, actx, insn, bound_expr, evaluate,
else:
func = self.exec_compute_potential_insn_fmm

def drive_fmm(wrangler, strengths, geo_data, kernel, kernel_arguments):
def drive_fmm(
actx, wrangler, strengths, geo_data, kernel, kernel_arguments):
del geo_data, kernel, kernel_arguments
from pytential.qbx.fmm import drive_fmm
if return_timing_data:
timing_data = {}
else:
timing_data = None
return drive_fmm(wrangler, strengths, timing_data), timing_data
return drive_fmm(actx, wrangler, strengths, timing_data), timing_data

extra_args["fmm_driver"] = drive_fmm

Expand Down Expand Up @@ -471,26 +472,25 @@ def cost_model_compute_potential_insn(self, actx, insn, bound_expr, evaluate,
raise NotImplementedError("perf modeling direct evaluations")

def drive_cost_model(
wrangler, strengths, geo_data, kernel, kernel_arguments):
actx, wrangler, strengths, geo_data, kernel, kernel_arguments):

if per_box:
cost_model_result, metadata = self.cost_model.qbx_cost_per_box(
actx.queue, geo_data, kernel, kernel_arguments,
actx, geo_data, kernel, kernel_arguments,
calibration_params
)
else:
cost_model_result, metadata = self.cost_model.qbx_cost_per_stage(
actx.queue, geo_data, kernel, kernel_arguments,
actx, geo_data, kernel, kernel_arguments,
calibration_params
)

from pytools.obj_array import obj_array_vectorize
from functools import partial
return (
obj_array_vectorize(
partial(wrangler.finalize_potentials,
template_ary=strengths[0]),
wrangler.full_output_zeros(strengths[0])),
partial(wrangler.finalize_potentials, actx),
wrangler.full_output_zeros(actx)),
(cost_model_result, metadata))

return self._dispatch_compute_potential_insn(
Expand Down Expand Up @@ -538,7 +538,7 @@ def _tree_indep_data_for_wrangler(self, source_kernels, target_kernels):
from pytential.qbx.fmm import \
QBXSumpyTreeIndependentDataForWrangler
return QBXSumpyTreeIndependentDataForWrangler(
self.cl_context,
self._setup_actx,
fmm_mpole_factory, fmm_local_factory, qbx_local_factory,
target_kernels=target_kernels, source_kernels=source_kernels)

Expand All @@ -551,7 +551,7 @@ def _tree_indep_data_for_wrangler(self, source_kernels, target_kernels):
from pytential.qbx.fmmlib import \
QBXFMMLibTreeIndependentDataForWrangler
return QBXFMMLibTreeIndependentDataForWrangler(
self.cl_context,
self._setup_actx,
multipole_expansion_factory=fmm_mpole_factory,
local_expansion_factory=fmm_local_factory,
qbx_local_expansion_factory=qbx_local_factory,
Expand Down Expand Up @@ -665,7 +665,7 @@ def exec_compute_potential_insn_fmm(self, actx: PyOpenCLArrayContext,
# Execute global QBX.
all_potentials_on_every_target, extra_outputs = (
fmm_driver(
wrangler, flat_strengths, geo_data,
actx, wrangler, flat_strengths, geo_data,
base_kernel, kernel_extra_kwargs))

results = []
Expand Down Expand Up @@ -746,7 +746,7 @@ def get_qbx_target_numberer(self, dtype):
assert dtype == np.int32
from pyopencl.scan import GenericScanKernel
return GenericScanKernel(
self.cl_context, np.int32,
self._setup_actx.context, np.int32,
arguments="int *tgt_to_qbx_center, int *qbx_tgt_number, int *count",
input_expr="tgt_to_qbx_center[i] >= 0 ? 1 : 0",
scan_expr="a+b", neutral="0",
Expand Down Expand Up @@ -837,7 +837,6 @@ def _flat_centers(dofdesc, qbx_forced_limit):

other_outputs[(o.target_name, qbx_forced_limit)].append((i, o))

queue = actx.queue
results = [None] * len(insn.outputs)

# }}}
Expand All @@ -854,7 +853,7 @@ def _flat_centers(dofdesc, qbx_forced_limit):
target_name.geometry, target_name.discr_stage)
flat_target_nodes = _flat_nodes(target_name)

evt, output_for_each_kernel = lpot_applier(actx,
output_for_each_kernel = lpot_applier(actx,
targets=flat_target_nodes,
sources=flat_source_nodes,
centers=_flat_centers(target_name, qbx_forced_limit),
Expand Down Expand Up @@ -885,7 +884,7 @@ def _flat_centers(dofdesc, qbx_forced_limit):
flat_target_nodes = _flat_nodes(target_name)

# FIXME: (Somewhat wastefully) compute P2P for all targets
evt, output_for_each_kernel = p2p(queue,
output_for_each_kernel = p2p(actx,
targets=flat_target_nodes,
sources=flat_source_nodes,
strength=flat_strengths,
Expand All @@ -911,7 +910,7 @@ def _flat_centers(dofdesc, qbx_forced_limit):

qbx_tgt_numberer(
tgt_to_qbx_center, qbx_tgt_numbers, qbx_tgt_count,
queue=queue)
queue=actx.queue)

qbx_tgt_count = int(actx.to_numpy(qbx_tgt_count).item())
if (abs(qbx_forced_limit) == 1 and qbx_tgt_count < target_discr.ndofs):
Expand Down
9 changes: 5 additions & 4 deletions pytential/qbx/cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,8 @@ def get_nqbx_centers_per_tgt_box(self,

# Build a mask (weight) of whether a target is a global qbx center
global_qbx_centers_tree_order = (
tree.sorted_target_ids[global_qbx_centers])
actx.thaw(tree.sorted_target_ids)[global_qbx_centers]
)

global_qbx_center_weight = actx.zeros(
tree.ntargets, dtype=tree.particle_id_dtype
Expand Down Expand Up @@ -779,7 +780,7 @@ def process_m2qbxl(self, actx: PyOpenCLArrayContext, geo_data, m2qbxl_cost):
nqbx_centers_itgt_box,
ssn.starts,
nm2qbxl,
m2qbxl_cost[isrc_level].reshape(-1)[0],
actx.to_numpy(m2qbxl_cost[isrc_level]).reshape(-1)[0],
queue=actx.queue
)

Expand All @@ -791,8 +792,8 @@ def process_l2qbxl(self, actx: PyOpenCLArrayContext, geo_data, l2qbxl_cost):
nqbx_centers_itgt_box = self.get_nqbx_centers_per_tgt_box(actx, geo_data)

# l2qbxl_cost_itgt_box = l2qbxl_cost[tree.box_levels[traversal.target_boxes]]
tgt_box_levels = tree.box_levels[traversal.target_boxes]
l2qbxl_cost_itgt_box = l2qbxl_cost[tgt_box_levels]
tgt_box_levels = actx.thaw(tree.box_levels)[traversal.target_boxes]
l2qbxl_cost_itgt_box = actx.thaw(l2qbxl_cost)[tgt_box_levels]

return nqbx_centers_itgt_box * l2qbxl_cost_itgt_box

Expand Down
4 changes: 2 additions & 2 deletions pytential/qbx/direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def get_kernel(self):
for i in range(self.strength_count)]
+ [lp.GlobalArg(f"result_{i}", self.value_dtypes[i],
shape="ntargets_total", order="C")
for i in range(len(self.target_kernels))])
for i in range(self.nresults)])

loopy_knl = make_loopy_program([
"{[itgt_local]: 0 <= itgt_local < ntargets}",
Expand All @@ -96,7 +96,7 @@ def get_kernel(self):
simul_reduce(sum, isrc, pair_result_{i}) \
{{inames=itgt_local}}
""".format(i=iknl)
for iknl in range(len(self.target_kernels))]
for iknl in range(self.nresults)]
+ ["end"],
kernel_data=arguments,
name=self.name,
Expand Down
2 changes: 1 addition & 1 deletion pytential/qbx/fmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def __init__(self, tree_indep, geo_data, dtype,
actx = geo_data._setup_actx

translation_classes_data, _ = translation_classes_builder(actx)(
actx.queue, traversal, traversal.tree, is_translation_per_level=True)
actx, traversal, traversal.tree, is_translation_per_level=True)

super().__init__(
tree_indep, traversal,
Expand Down
10 changes: 5 additions & 5 deletions pytential/qbx/fmmlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,8 @@ def full_output_zeros(self, actx: PyOpenCLArrayContext):
for k in self.tree_indep.outputs])

def reorder_sources(self, source_array):
if isinstance(source_array, cl.array.Array):
source_array = source_array.get()
if isinstance(source_array, self._setup_actx.array_types):
source_array = self._setup_actx.to_numpy(source_array)

return super().reorder_sources(source_array)

Expand Down Expand Up @@ -550,7 +550,7 @@ def translate_box_local_to_qbx_local(self, actx, local_exps):
@log_process(logger)
@return_timing_data
def eval_qbx_expansions(self, actx, qbx_expansions):
output = self.full_output_zeros(template_ary=qbx_expansions)
output = self.full_output_zeros(actx)

geo_data = self.geo_data
ctt = geo_data.center_to_tree_targets()
Expand Down Expand Up @@ -587,7 +587,7 @@ def eval_qbx_expansions(self, actx, qbx_expansions):
def eval_target_specific_qbx_locals(self, actx, src_weight_vecs):
src_weights, = src_weight_vecs
if not self.tree_indep.using_tsqbx:
return self.full_output_zeros(template_ary=src_weights)
return self.full_output_zeros(actx)

geo_data = self.geo_data
trav = geo_data.traversal()
Expand Down Expand Up @@ -632,7 +632,7 @@ def eval_target_specific_qbx_locals(self, actx, src_weight_vecs):
pot=pot,
grad=grad)

output = self.full_output_zeros(template_ary=src_weights)
output = self.full_output_zeros(actx)
self.add_potgrad_onto_output(output, slice(None), pot, grad)

return output
Expand Down
19 changes: 9 additions & 10 deletions pytential/qbx/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@
Subordinate data structures
^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: TargetInfo()
.. autoclass:: TargetInfo
.. autoclass:: CenterToTargetList()
.. autoclass:: CenterToTargetList
Enums of special values
^^^^^^^^^^^^^^^^^^^^^^^
Expand Down Expand Up @@ -285,7 +285,6 @@ def make_container(

# {{{ geometry data

@dataclass_array_container
@dataclass(frozen=True)
class TargetInfo:
"""Describes the internal structure of the QBX FMM's list of :attr:`targets`.
Expand Down Expand Up @@ -573,7 +572,7 @@ def tree(self):

refine_weights.finish()

tree, _ = code_getter.build_tree()(actx.queue,
tree, _ = code_getter.build_tree()(actx,
particles=flatten(
quad_stage2_discr.nodes(), actx, leaf_class=DOFArray
),
Expand Down Expand Up @@ -605,13 +604,13 @@ def traversal(self, merge_close_lists=True):
"""

actx = self._setup_actx
trav, _ = self.code_getter.build_traversal(actx.queue, self.tree(),
trav, _ = self.code_getter.build_traversal(actx, self.tree(),
debug=self.debug,
_from_sep_smaller_min_nsources_cumul=(
self.lpot_source._from_sep_smaller_min_nsources_cumul))

if merge_close_lists and self.lpot_source._expansions_in_tree_have_extent:
trav = trav.merge_close_lists(actx.queue)
trav = trav.merge_close_lists(actx)

return actx.freeze(trav)

Expand Down Expand Up @@ -869,7 +868,7 @@ def non_qbx_box_target_lists(self):

tree = self.tree()
plfilt = self.code_getter.particle_list_filter()
result = plfilt.filter_target_lists_in_tree_order(actx.queue, tree, flags)
result = plfilt.filter_target_lists_in_tree_order(actx, tree, flags)

return actx.freeze(result)

Expand All @@ -879,8 +878,8 @@ def build_rotation_classes_lists(self):
trav = self.traversal()
tree = self.tree()

result = self.code_getter.rotation_classes_builder(actx.queue, trav, tree)
return result[0].get(queue=actx.queue)
result = self.code_getter.rotation_classes_builder(actx, trav, tree)
return actx.to_numpy(result[0])

@memoize_method
def m2l_rotation_lists(self):
Expand Down Expand Up @@ -926,7 +925,7 @@ def plot(self, draw_circles=False, draw_center_numbers=False,

global_flags = actx.to_numpy(self.global_qbx_flags())

tree = self.tree().get(queue=actx.queue)
tree = actx.to_numpy(self.tree())
from boxtree.visualization import TreePlotter
tp = TreePlotter(tree)
tp.draw_tree()
Expand Down
Loading

0 comments on commit 6737ff8

Please sign in to comment.