Skip to content

Commit

Permalink
rip out timing collection
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Aug 2, 2023
1 parent f9a42a3 commit adf6d2a
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 208 deletions.
6 changes: 2 additions & 4 deletions examples/cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ def calibrate_cost_model(ctx):

for _ in range(RUNS):
timing_data = {}
bound_op.eval({"sigma": sigma}, array_context=actx,
timing_data=timing_data)
bound_op.eval({"sigma": sigma}, array_context=actx)

model_results.append(modeled_cost)
timing_results.append(timing_data)
Expand Down Expand Up @@ -175,8 +174,7 @@ def test_cost_model(ctx, calibration_params):
temp_timing_results = []
for _ in range(RUNS):
timing_data = {}
bound_op.eval({"sigma": sigma},
array_context=actx, timing_data=timing_data)
bound_op.eval({"sigma": sigma}, array_context=actx)
temp_timing_results.append(one(timing_data.values()))

timing_result = {}
Expand Down
50 changes: 10 additions & 40 deletions pytential/qbx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,13 +425,11 @@ def op_group_features(self, expr):

# {{{ internal functionality for execution

def exec_compute_potential_insn(self, actx, insn, bound_expr, evaluate,
return_timing_data):
def exec_compute_potential_insn(self, actx, insn, bound_expr, evaluate):
extra_args = {}

if self.fmm_level_to_order is False:
func = self.exec_compute_potential_insn_direct
extra_args["return_timing_data"] = return_timing_data

else:
func = self.exec_compute_potential_insn_fmm
Expand All @@ -440,11 +438,7 @@ def drive_fmm(
actx, wrangler, strengths, geo_data, kernel, kernel_arguments):
del geo_data, kernel, kernel_arguments
from pytential.qbx.fmm import drive_fmm
if return_timing_data:
timing_data = {}
else:
timing_data = None
return drive_fmm(actx, wrangler, strengths, timing_data), timing_data
return drive_fmm(actx, wrangler, strengths)

extra_args["fmm_driver"] = drive_fmm

Expand Down Expand Up @@ -473,25 +467,13 @@ def cost_model_compute_potential_insn(self, actx, insn, bound_expr, evaluate,

def drive_cost_model(
actx, wrangler, strengths, geo_data, kernel, kernel_arguments):

if per_box:
cost_model_result, metadata = self.cost_model.qbx_cost_per_box(
actx, geo_data, kernel, kernel_arguments,
calibration_params
)
else:
cost_model_result, metadata = self.cost_model.qbx_cost_per_stage(
actx, geo_data, kernel, kernel_arguments,
calibration_params
)

from pytools.obj_array import obj_array_vectorize
from functools import partial
return (
obj_array_vectorize(
partial(wrangler.finalize_potentials, actx),
wrangler.full_output_zeros(actx)),
(cost_model_result, metadata))
wrangler.full_output_zeros(actx))
)

return self._dispatch_compute_potential_insn(
actx, insn, bound_expr, evaluate,
Expand Down Expand Up @@ -595,11 +577,8 @@ def exec_compute_potential_insn_fmm(self, actx: PyOpenCLArrayContext,
"""
:arg fmm_driver: A function that accepts four arguments:
*wrangler*, *strength*, *geo_data*, *kernel*, *kernel_arguments*
:returns: a tuple ``(assignments, extra_outputs)``, where *assignments*
is a list of tuples containing pairs ``(name, value)`` representing
assignments to be performed in the evaluation context.
*extra_outputs* is data that *fmm_driver* may return
(such as timing data), passed through unmodified.
:returns: a list of assignments containing pairs ``(name, value)``
representing assignments to be performed in the evaluation context.
"""
target_name_and_side_to_number, target_discrs_and_qbx_sides = (
self.get_target_discrs_and_qbx_sides(insn, bound_expr))
Expand Down Expand Up @@ -663,7 +642,7 @@ def exec_compute_potential_insn_fmm(self, actx: PyOpenCLArrayContext,
# }}}

# Execute global QBX.
all_potentials_on_every_target, extra_outputs = (
all_potentials_on_every_target = (
fmm_driver(
actx, wrangler, flat_strengths, geo_data,
base_kernel, kernel_extra_kwargs))
Expand All @@ -686,7 +665,7 @@ def exec_compute_potential_insn_fmm(self, actx: PyOpenCLArrayContext,

results.append((o.name, result))

return results, extra_outputs
return results

# }}}

Expand Down Expand Up @@ -758,18 +737,10 @@ def get_qbx_target_numberer(self, dtype):
*count = item;
""")

def exec_compute_potential_insn_direct(self, actx, insn, bound_expr, evaluate,
return_timing_data):
def exec_compute_potential_insn_direct(self, actx, insn, bound_expr, evaluate):
from pytential import bind, sym
from meshmode.discretization import Discretization

if return_timing_data:
from pytential.source import UnableToCollectTimingData
from warnings import warn
warn(
"Timing data collection not supported.",
category=UnableToCollectTimingData)

# {{{ evaluate and flatten inputs

@memoize_in(bound_expr.places,
Expand Down Expand Up @@ -947,8 +918,7 @@ def _flat_centers(dofdesc, qbx_forced_limit):

# }}}

timing_data = {}
return results, timing_data
return results

# }}}

Expand Down
Loading

0 comments on commit adf6d2a

Please sign in to comment.