Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix indicator box #1486

Merged
merged 13 commits into from
Aug 2, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

* x.x.x
- fix bug in IndicatorBox proximal_conjugate
- allow CCPi Regulariser functions for not CIL object
- Add norm for CompositionOperator.
- Refactor SIRT algorithm to make it more computationally and memory efficient
Expand Down
30 changes: 0 additions & 30 deletions Wrappers/Python/cil/optimisation/functions/IndicatorBox.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,36 +142,6 @@ def __call__(self, x):
return self.evaluate(x)
return 0.0

def proximal_conjugate(self, x, tau, out=None):
r'''Proximal operator of the convex conjugate of IndicatorBox at x:

.. math:: prox_{\tau * f^{*}}(x)

Parameters
----------
x : DataContainer
Input to the proximal operator
tau : float
Step size. Notice it is ignored in IndicatorBox, see ``proximal`` for details
out : DataContainer, optional
Output of the proximal operator. If not provided, a new DataContainer is created.

'''

# x - tau * self.proximal(x/tau, tau)
should_return = False

if out is None:
out = self.proximal(x, tau)
should_return = True
else:
self.proximal(x, tau, out=out)

out.sapyb(-1., x, 1., out=out)

if should_return:
return out

def proximal(self, x, tau, out=None):
r'''Proximal operator of IndicatorBox at x

Expand Down
46 changes: 46 additions & 0 deletions Wrappers/Python/test/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,52 @@ def test_Lipschitz4(self):
f4 = -2 * f2
assert f4.L == 2 * f2.L

def test_proximal_conjugate(self):
from cil.framework import AcquisitionGeometry, BlockGeometry
ag = AcquisitionGeometry.create_Parallel2D()
angles = np.linspace(0, 360, 10, dtype=np.float32)

#default
ag.set_angles(angles)
ag.set_panel(5)

ig = ag.get_ImageGeometry()
bg = BlockGeometry(ig, ig)

b = ag.allocate('random', seed=2)

func_geom_test_list = [
(IndicatorBox(), ag),
(KullbackLeibler(b=b, backend='numba'), ag),
(KullbackLeibler(b=b, backend='numpy'), ag),
(L1Norm(), ag),
(L2NormSquared(), ag),
(MixedL21Norm(), bg),
(TotalVariation(backend='c'), ig),
(TotalVariation(backend='numpy'), ig),
]

for func, geom in func_geom_test_list:
self.proximal_conjugate_test(func, geom)

def proximal_conjugate_test(self, function, geom):
x = geom.allocate('random', seed=1)
tau = 1.0
f = Function()
f.proximal = function.proximal

a = function.proximal_conjugate(x, tau)
b = f.proximal_conjugate(x, tau)

# handle the case of MixedL21Norm
if isinstance(a, BlockDataContainer):
for xa,xb in zip(a,b):
np.testing.assert_allclose(xa.as_array(), xb.as_array(),
rtol=1e-5, atol=1e-5)
else:
np.testing.assert_allclose(a.as_array(), b.as_array(),
rtol=1e-5, atol=1e-5)


class TestTotalVariation(unittest.TestCase):

Expand Down