-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
728c3d2
commit 98c560b
Showing
1 changed file
with
119 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
import pytest | ||
pytest.importorskip("pyopencl") | ||
|
||
import sys | ||
import petsc4py | ||
petsc4py.init(sys.argv | ||
+ "-viennacl_backend opencl".split() | ||
+ "-viennacl_opencl_device_type cpu".split()) | ||
from pyop2 import op2 | ||
import pyopencl.array as cla | ||
import numpy as np | ||
import loopy as lp | ||
lp.set_caching_enabled(False) | ||
|
||
|
||
def allclose(a, b, rtol=1e-05, atol=1e-08): | ||
""" | ||
Prefer this routine over np.allclose(...) to allow pycuda/pyopencl arrays | ||
""" | ||
return bool(abs(a - b) < (atol + rtol * abs(b))) | ||
|
||
|
||
def pytest_generate_tests(metafunc): | ||
if "backend" in metafunc.fixturenames: | ||
from pyop2.backends.opencl import opencl_backend | ||
metafunc.parametrize("backend", [opencl_backend]) | ||
|
||
|
||
def test_new_backend_raises_not_implemented_error(): | ||
from pyop2.backends import AbstractComputeBackend | ||
unimplemented_backend = AbstractComputeBackend() | ||
|
||
attrs = ["GlobalKernel", "Parloop", "Set", "ExtrudedSet", "MixedSet", | ||
"Subset", "DataSet", "MixedDataSet", "Map", "MixedMap", "Dat", | ||
"MixedDat", "DatView", "Mat", "Global", "GlobalDataSet", | ||
"PETScVecType"] | ||
|
||
for attr in attrs: | ||
with pytest.raises(NotImplementedError): | ||
getattr(unimplemented_backend, attr) | ||
|
||
|
||
def test_dat_with_petscvec_representation(backend): | ||
op2.set_offloading_backend(backend) | ||
|
||
nelems = 9 | ||
data = np.random.rand(nelems) | ||
set_ = op2.compute_backend.Set(nelems) | ||
dset = op2.compute_backend.DataSet(set_, 1) | ||
dat = op2.compute_backend.Dat(dset, data.copy()) | ||
|
||
assert isinstance(dat.data_ro, np.ndarray) | ||
dat.data[:] *= 3 | ||
|
||
with op2.offloading(): | ||
assert isinstance(dat.data_ro, cla.Array) | ||
dat.data[:] *= 2 | ||
|
||
assert isinstance(dat.data_ro, np.ndarray) | ||
np.testing.assert_allclose(dat.data_ro, 6*data) | ||
|
||
|
||
def test_dat_not_as_petscvec(backend): | ||
op2.set_offloading_backend(backend) | ||
|
||
nelems = 9 | ||
data = np.random.randint(low=-10, high=10, | ||
size=nelems, | ||
dtype=np.int64) | ||
set_ = op2.compute_backend.Set(nelems) | ||
dset = op2.compute_backend.DataSet(set_, 1) | ||
dat = op2.compute_backend.Dat(dset, data.copy()) | ||
|
||
assert isinstance(dat.data_ro, np.ndarray) | ||
dat.data[:] *= 3 | ||
|
||
with op2.offloading(): | ||
assert isinstance(dat.data_ro, cla.Array) | ||
dat.data[:] *= 2 | ||
|
||
assert isinstance(dat.data_ro, np.ndarray) | ||
np.testing.assert_allclose(dat.data_ro, 6*data) | ||
|
||
|
||
def test_global_reductions(backend): | ||
op2.set_offloading_backend(backend) | ||
|
||
sum_knl = lp.make_function( | ||
"{ : }", | ||
""" | ||
g[0] = g[0] + x[0] | ||
""", | ||
[lp.GlobalArg("g,x", | ||
dtype="float64", | ||
shape=lp.auto, | ||
)], | ||
name="sum_knl", | ||
target=lp.CWithGNULibcTarget(), | ||
lang_version=(2018, 2)) | ||
|
||
rng = np.random.default_rng() | ||
nelems = 4_000 | ||
data_to_sum = rng.random(nelems) | ||
|
||
with op2.offloading(): | ||
|
||
set_ = op2.compute_backend.Set(4_000) | ||
dset = op2.compute_backend.DataSet(set_, 1) | ||
|
||
dat = op2.compute_backend.Dat(dset, data_to_sum.copy()) | ||
glob = op2.compute_backend.Global(1, 0, np.float64, "g") | ||
|
||
op2.parloop(op2.Kernel(sum_knl, "sum_knl"), | ||
set_, | ||
glob(op2.INC), | ||
dat(op2.READ)) | ||
|
||
assert isinstance(glob.data_ro, cla.Array) | ||
assert allclose(glob.data_ro[0], data_to_sum.sum()) |