Skip to content

Commit

Permalink
Add SplitPytatoArrayContext to the test suite
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikcfd committed Jan 24, 2023
1 parent d8787e9 commit 11583d7
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
13 changes: 12 additions & 1 deletion arraycontext/pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,22 @@ def __call__(self):
return self.actx_class(queue, allocator=alloc)

def __str__(self):
return ("<PytatoPyOpenCLArrayContext for <pyopencl.Device '%s' on '%s'>>" %
return ("<%s for <pyopencl.Device '%s' on '%s'>>" %
(
self.__class__.__name__,
self.device.name.strip(),
self.device.platform.name.strip()))


class _PytestSplitPytatoPyOpenCLArrayContextFactory(
_PytestPytatoPyOpenCLArrayContextFactory):
@property
def actx_class(self):
from arraycontext.impl.pytato.split_actx import (
SplitPytatoPyOpenCLArrayContext)
return SplitPytatoPyOpenCLArrayContext


class _PytestEagerJaxArrayContextFactory(PytestArrayContextFactory):
def __init__(self, *args, **kwargs):
pass
Expand Down Expand Up @@ -231,6 +241,7 @@ def __str__(self):
_PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars,
"pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory,
"pytato:jax": _PytestPytatoJaxArrayContextFactory,
"pytato:split": _PytestSplitPytatoPyOpenCLArrayContextFactory,
"eagerjax": _PytestEagerJaxArrayContextFactory,
}

Expand Down
4 changes: 3 additions & 1 deletion test/test_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@
serialize_container, tag_axes, with_array_context, with_container_arithmetic)
from arraycontext.pytest import (
_PytestEagerJaxArrayContextFactory, _PytestPyOpenCLArrayContextFactoryWithClass,
_PytestPytatoJaxArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory)
_PytestPytatoJaxArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory,
_PytestSplitPytatoPyOpenCLArrayContextFactory)


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -84,6 +85,7 @@ class _PytatoPyOpenCLArrayContextForTestsFactory(
_PytatoPyOpenCLArrayContextForTestsFactory,
_PytestEagerJaxArrayContextFactory,
_PytestPytatoJaxArrayContextFactory,
_PytestSplitPytatoPyOpenCLArrayContextFactory,
])


Expand Down

0 comments on commit 11583d7

Please sign in to comment.