From 6ad30d698dd655902efab7f9dee5f3c76592b942 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Tue, 10 Oct 2023 15:34:08 +0100 Subject: [PATCH] add unit test for _proximal_step_numpy for float, DataContainer and ndarray update docstring update changelog --- CHANGELOG.md | 3 ++ .../optimisation/functions/MixedL21Norm.py | 4 +- Wrappers/Python/test/test_functions.py | 47 ++++++++++++++++++- 3 files changed, 50 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 349e90a9c7..12ed393a0a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,6 @@ +* xx.x.x + - bugfix + - proximal of MixedL21Norm with numpy backend * 23.1.0 - Fix bug in IndicatorBox proximal_conjugate diff --git a/Wrappers/Python/cil/optimisation/functions/MixedL21Norm.py b/Wrappers/Python/cil/optimisation/functions/MixedL21Norm.py index 7cbe332d2f..bb9d38be15 100644 --- a/Wrappers/Python/cil/optimisation/functions/MixedL21Norm.py +++ b/Wrappers/Python/cil/optimisation/functions/MixedL21Norm.py @@ -62,8 +62,8 @@ def _proximal_step_numpy(arr, tau): Parameters: ----------- - tmp : DataContainer/ numpy array, best if contiguous memory. - tau: float or DataContainer + arr : DataContainer/ numpy array, best if contiguous memory. + tau: float, numpy array or DataContainer Returns: -------- diff --git a/Wrappers/Python/test/test_functions.py b/Wrappers/Python/test/test_functions.py index 9f05973824..56b8e6da4c 100644 --- a/Wrappers/Python/test/test_functions.py +++ b/Wrappers/Python/test/test_functions.py @@ -408,9 +408,52 @@ def test_MixedL21Norm_step(self): # check they are the same np.testing.assert_allclose(res1, res2.as_array(), atol=1e-5, rtol=1e-6) - def test_MixedL21Norm_proximal_numpy(self): - assert True + def test_MixedL21Norm_proximal_step_numpy_float(self): + from cil.optimisation.functions.MixedL21Norm import _proximal_step_numpy + from cil.framework import ImageGeometry + + tau = 1.1 + + ig = ImageGeometry(2,3,4) + tmp = ig.allocate(1) + a = _proximal_step_numpy(tmp, tau) + + b = _proximal_step_numpy(tmp, -tau) + + np.testing.assert_allclose(a.as_array(), b.as_array()) + + def test_MixedL21Norm_proximal_step_numpy_dc(self): + from cil.optimisation.functions.MixedL21Norm import _proximal_step_numpy + from cil.framework import ImageGeometry + + ig = ImageGeometry(2,3,4) + tmp = ig.allocate(1) + tau = ig.allocate(2) + a = _proximal_step_numpy(tmp, tau) + + tau *= -1 + b = _proximal_step_numpy(tmp, tau) + + np.testing.assert_allclose(a.as_array(), b.as_array()) + + def test_MixedL21Norm_proximal_step_numpy_ndarray(self): + from cil.optimisation.functions.MixedL21Norm import _proximal_step_numpy + from cil.framework import ImageGeometry + + + ig = ImageGeometry(2,3,4) + tmp = ig.allocate(1) + tau = ig.allocate(2) + tauarr = tau.as_array() + a = _proximal_step_numpy(tmp, tauarr) + + tauarr *= -1 + b = _proximal_step_numpy(tmp, tauarr) + + np.testing.assert_allclose(a.as_array(), b.as_array()) + + def test_smoothL21Norm(self): ig = ImageGeometry(4, 5) bg = BlockGeometry(ig, ig)