Skip to content

Commit

Permalink
Update.
Browse files Browse the repository at this point in the history
  • Loading branch information
nkoskelo committed Jul 30, 2024
1 parent 478fa3b commit aed5080
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 53 deletions.
5 changes: 4 additions & 1 deletion arraycontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,15 @@
from .impl.pyopencl import PyOpenCLArrayContext
from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext
from .loopy import make_loopy_program
from .parameter_study import pack_for_parameter_study, unpack_parameter_study
from .pytest import (
PytestArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
pytest_generate_tests_for_array_contexts,
pytest_generate_tests_for_pyopencl_array_context,
)
from .transform_metadata import CommonSubexpressionTag, ElementwiseMapKernelTag
from .parameter_study import pack_for_parameter_study, unpack_parameter_study


__all__ = (
"Array",
Expand Down Expand Up @@ -132,6 +133,7 @@
"multimap_reduce_array_container",
"multimapped_over_array_containers",
"outer",
"pack_for_parameter_study",
"pytest_generate_tests_for_array_contexts",
"pytest_generate_tests_for_pyopencl_array_context",
"rec_map_array_container",
Expand All @@ -145,6 +147,7 @@
"thaw",
"to_numpy",
"unflatten",
"unpack_parameter_study",
"with_array_context",
"with_container_arithmetic"
)
Expand Down
5 changes: 1 addition & 4 deletions arraycontext/impl/pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,12 @@
import numpy as np

from pytools import memoize_method
from pytools.tag import Tag, ToTagSetConvertible, normalize_tags, UniqueTag
from pytools.tag import Tag, ToTagSetConvertible, UniqueTag as UniqueTag, normalize_tags

from arraycontext.container.traversal import rec_map_array_container, with_array_context
from arraycontext.context import Array, ArrayContext, ArrayOrContainer, ScalarLike
from arraycontext.metadata import NameHint

from dataclasses import dataclass

if TYPE_CHECKING:
import pyopencl as cl
Expand Down Expand Up @@ -703,8 +702,6 @@ def clone(self):
# }}}




# {{{ PytatoJAXArrayContext

class PytatoJAXArrayContext(_BasePytatoArrayContext):
Expand Down
68 changes: 35 additions & 33 deletions arraycontext/parameter_study/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from future import __annotations__


"""
.. currentmodule:: arraycontext
Expand Down Expand Up @@ -48,45 +46,51 @@
THE SOFTWARE.
"""

import sys
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Mapping,
Tuple,
Type,
)

import numpy as np

import loopy as lp
from pytato.array import (Array, make_placeholder as make_placeholder,
make_dict_of_named_arrays)

from pytato.transform.parameter_study import ParameterStudyAxisTag
from pytato.array import (
Array,
AxesT,
ShapeType,
make_dict_of_named_arrays,
make_placeholder as make_placeholder,
)
from pytato.transform.parameter_study import (
ExpansionMapper,
ParameterStudyAxisTag,
)
from pytools.tag import Tag, UniqueTag as UniqueTag

from arraycontext.context import ArrayContext
from arraycontext.container import ArrayContainer, is_array_container_type
from arraycontext.container import (
ArrayContainer as ArrayContainer,
is_array_container_type,
)
from arraycontext.container.traversal import rec_keyed_map_array_container
from arraycontext.impl.pytato import PytatoPyOpenCLArrayContext
from arraycontext.impl.pytato.compile import (LazilyPyOpenCLCompilingFunctionCaller,
_to_input_for_compiled)
from arraycontext.context import ArrayContext
from arraycontext.impl.pytato import (
PytatoPyOpenCLArrayContext,
_get_arg_id_to_arg_and_arg_id_to_descr,
)
from arraycontext.impl.pytato.compile import (
LazilyPyOpenCLCompilingFunctionCaller,
LeafArrayDescriptor,
_ary_container_key_stringifier,
_to_input_for_compiled,
)


ArraysT = Tuple[Array, ...]
StudiesT = Tuple[ParameterStudyAxisTag, ...]
ArraysT = tuple[Array, ...]
StudiesT = tuple[ParameterStudyAxisTag, ...]
ParamStudyTagT = Type[ParameterStudyAxisTag]

if TYPE_CHECKING:
import pyopencl as cl
import pytato as pytato

if getattr(sys, "_BUILDING_SPHINX_DOCS", False):
import pyopencl as cl

import logging

Expand Down Expand Up @@ -168,7 +172,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
self.actx._compile_trace_callback(self.f, "post_trace", output_template)

if (not (is_array_container_type(output_template.__class__)
or isinstance(output_template, pt.Array))):
or isinstance(output_template, Array))):
# TODO: We could possibly just short-circuit this interface if the
# returned type is a scalar. Not sure if it's worth it though.
raise NotImplementedError(
Expand All @@ -185,9 +189,7 @@ def _as_dict_of_named_arrays(keys, ary):
rec_keyed_map_array_container(_as_dict_of_named_arrays,
output_template)

input_shapes = {}
input_axes = {}
placeholder_name_to_parameter_studies: Dict[str, StudiesT] = {}
placeholder_name_to_parameter_studies: dict[str, StudiesT] = {}
for key, val in arg_id_to_descr.items():
if isinstance(val, LeafArrayDescriptor):
name = input_id_to_name_in_program[key]
Expand Down Expand Up @@ -240,7 +242,7 @@ def _cut_to_single_instance_size(name, arg) -> Array:
update_axes = (*update_axes, arg.axes[i],)
newshape = (*newshape, arg.shape[i])

update_tags: FrozenSet[Tag] = arg.tags
update_tags: frozenset[Tag] = arg.tags

return make_placeholder(name, newshape, arg.dtype, axes=update_axes,
tags=update_tags)
Expand Down Expand Up @@ -282,7 +284,7 @@ def pack_for_parameter_study(actx: ArrayContext,
study_name_tag_type: ParamStudyTagT,
*args: Array) -> Array:
"""
Args is a list of realized input data that needs to be packed
Args is a list of realized input data that needs to be packed
for a parameter study or uncertainty quantification.
We assume that each input data set has the same shape and
Expand All @@ -301,7 +303,7 @@ def pack_for_parameter_study(actx: ArrayContext,

def unpack_parameter_study(data: Array,
study_name_tag_type: ParamStudyTagT) -> Mapping[int,
List[Array]]:
list[Array]]:
"""
Split the data array along the axes which vary according to
a ParameterStudyAxisTag whose name tag is an instance study_name_tag_type.
Expand All @@ -311,7 +313,7 @@ def unpack_parameter_study(data: Array,
"""

ndim: int = len(data.shape)
out: Dict[int, List[Array]] = {}
out: dict[int, list[Array]] = {}

study_count = 0
for i in range(ndim):
Expand All @@ -320,7 +322,7 @@ def unpack_parameter_study(data: Array,
# Now we need to split this data.
breakpoint()
for j in range(data.shape[i]):
tmp: List[Any] = [slice(None)] * ndim
tmp: list[Any] = [slice(None)] * ndim
tmp[i] = j
the_slice = tuple(tmp)
# Needs to be a tuple of slices not list of slices.
Expand Down
2 changes: 2 additions & 0 deletions arraycontext/pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(self, device):

@classmethod
def is_available(cls) -> bool:
return False
try:
import pyopencl # noqa: F401
return True
Expand Down Expand Up @@ -133,6 +134,7 @@ class _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars(
class _PytestPytatoPyOpenCLArrayContextFactory(PytestPyOpenCLArrayContextFactory):
@classmethod
def is_available(cls) -> bool:
return True
try:
import pyopencl # noqa: F401
import pytato # noqa: F401
Expand Down
22 changes: 11 additions & 11 deletions examples/advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
queue = cl.CommandQueue(ctx)
actx = ParamStudyPytatoPyOpenCLArrayContext(queue)


@dataclass(frozen=True)
class ParamStudy1(ParameterStudyAxisTag):
"""
Expand All @@ -27,7 +28,6 @@ class ParamStudy1(ParameterStudyAxisTag):

def test_one_time_step_advection():

import numpy as np
seed = 12345
rng = np.random.default_rng(seed)

Expand All @@ -36,46 +36,46 @@ def test_one_time_step_advection():
x1 = actx.from_numpy(rng.random(base_shape))
x2 = actx.from_numpy(rng.random(base_shape))
x3 = actx.from_numpy(rng.random(base_shape))


speed_shape = (1,)
y0 = actx.from_numpy(rng.random(speed_shape))
y1 = actx.from_numpy(rng.random(speed_shape))
y2 = actx.from_numpy(rng.random(speed_shape))
y3 = actx.from_numpy(rng.random(speed_shape))


ht = 0.0001
hx = 0.005
inds = np.arange(base_shape, dtype=int)
Kp1 = actx.from_numpy(np.roll(inds, -1))
Km1 = actx.from_numpy(np.roll(inds, 1))
kp1 = actx.from_numpy(np.roll(inds, -1))
km1 = actx.from_numpy(np.roll(inds, 1))

def rhs(fields, wave_speed):
# 2nd order in space finite difference
return fields + wave_speed * (-1) * (ht / (2 * hx)) * \
(fields[Kp1] - fields[Km1])
(fields[kp1] - fields[km1])

pack_x = pack_for_parameter_study(actx, ParamStudy1, (4,), x0, x1, x2, x3)
breakpoint()
assert pack_x.shape == (75,4)
assert pack_x.shape == (75, 4)

pack_y = pack_for_parameter_study(actx, ParamStudy1, (4,), y0,y1, y2,y3)
pack_y = pack_for_parameter_study(actx, ParamStudy1, (4,), y0, y1, y2, y3)
breakpoint()
assert pack_y.shape == (1,4)
assert pack_y.shape == (1, 4)

compiled_rhs = actx.compile(rhs)
breakpoint()

output = compiled_rhs(pack_x, pack_y)
breakpoint()
assert output.shape(75,4)
assert output.shape(75, 4)

output_x = unpack_parameter_study(output, ParamStudy1)
assert len(output_x) == 1 # Only 1 study associated with this variable.
assert len(output_x[0]) == 4 # 4 inputs for the parameter study.
assert len(output_x[0]) == 4 # 4 inputs for the parameter study.

print("All checks passed")

# Call it.


test_one_time_step_advection()
4 changes: 2 additions & 2 deletions examples/parameter_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import pyopencl as cl

from arraycontext.parameter_study import (
ParameterStudyAxisTag,
ParamStudyPytatoPyOpenCLArrayContext,
pack_for_parameter_study,
unpack_parameter_study,
ParamStudyPytatoPyOpenCLArrayContext,
ParameterStudyAxisTag,
)


Expand Down
4 changes: 2 additions & 2 deletions test/test_pytato_parameter_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@
pytest_generate_tests_for_array_contexts,
)
from arraycontext.parameter_study import (
pack_for_parameter_study,
unpack_parameter_study,
ParameterStudyAxisTag,
ParamStudyPytatoPyOpenCLArrayContext,
pack_for_parameter_study,
unpack_parameter_study,
)
from arraycontext.pytest import _PytestPytatoPyOpenCLArrayContextFactory

Expand Down

0 comments on commit aed5080

Please sign in to comment.