Skip to content

Commit

Permalink
add unit test for _proximal_step_numpy for float, DataContainer and n…
Browse files Browse the repository at this point in the history
…darray

update docstring
  • Loading branch information
paskino committed Oct 17, 2023
1 parent ca73865 commit b2dd037
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 4 deletions.
4 changes: 2 additions & 2 deletions Wrappers/Python/cil/optimisation/functions/MixedL21Norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def _proximal_step_numpy(arr, tau):
Parameters:
-----------
tmp : DataContainer/ numpy array, best if contiguous memory.
tau: float or DataContainer
arr : DataContainer/ numpy array, best if contiguous memory.
tau: float, numpy array or DataContainer
Returns:
--------
Expand Down
47 changes: 45 additions & 2 deletions Wrappers/Python/test/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,9 +408,52 @@ def test_MixedL21Norm_step(self):
# check they are the same
np.testing.assert_allclose(res1, res2.as_array(), atol=1e-5, rtol=1e-6)

def test_MixedL21Norm_proximal_numpy(self):
assert True
def test_MixedL21Norm_proximal_step_numpy_float(self):
from cil.optimisation.functions.MixedL21Norm import _proximal_step_numpy
from cil.framework import ImageGeometry

tau = 1.1

ig = ImageGeometry(2,3,4)
tmp = ig.allocate(1)
a = _proximal_step_numpy(tmp, tau)

b = _proximal_step_numpy(tmp, -tau)

np.testing.assert_allclose(a.as_array(), b.as_array())

def test_MixedL21Norm_proximal_step_numpy_dc(self):
from cil.optimisation.functions.MixedL21Norm import _proximal_step_numpy
from cil.framework import ImageGeometry


ig = ImageGeometry(2,3,4)
tmp = ig.allocate(1)
tau = ig.allocate(2)
a = _proximal_step_numpy(tmp, tau)

tau *= -1
b = _proximal_step_numpy(tmp, tau)

np.testing.assert_allclose(a.as_array(), b.as_array())

def test_MixedL21Norm_proximal_step_numpy_ndarray(self):
from cil.optimisation.functions.MixedL21Norm import _proximal_step_numpy
from cil.framework import ImageGeometry


ig = ImageGeometry(2,3,4)
tmp = ig.allocate(1)
tau = ig.allocate(2)
tauarr = tau.as_array()
a = _proximal_step_numpy(tmp, tauarr)

tauarr *= -1
b = _proximal_step_numpy(tmp, tauarr)

np.testing.assert_allclose(a.as_array(), b.as_array())


def test_smoothL21Norm(self):
ig = ImageGeometry(4, 5)
bg = BlockGeometry(ig, ig)
Expand Down

0 comments on commit b2dd037

Please sign in to comment.