Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

callback improvements #1927

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
- Allow `SumFunction` with 1 item (#1857)
- Added PD3O algorithm (#1834)
- Added Barzilai-Borwein step size rule to work with GD, ISTA, FISTA (#1859)
- Added callback `optimisation.utilities.callbacks.EarlyStoppingObjectiveValue` which stops iterations if an algorithm objective changes less than a provided threshold (#1892)
- Added callback `optimisation.utilities.callbacks.CGLSEarlyStopping` which replicates the automatic behaviour of CGLS in CIL versions <=24. (#1892)
- Added callback `optimisation.utilities.callbacks.EarlyStopping` which stops iterations if an algorithm objective changes less than a provided threshold (#1892)
- Added callback `optimisation.utilities.callbacks.EarlyStoppingCGLS` which replicates the automatic behaviour of CGLS in CIL versions <=24. (#1892)
- Added `labels` module with `ImageDimension`, `AcquisitionDimension`, `AcquisitionType`, `AngleUnit`, `FillType` (#1692)
- Enhancements:
- Use ravel instead of flat in KullbackLeibler numba backend (#1874)
Expand Down
36 changes: 18 additions & 18 deletions Wrappers/Python/cil/optimisation/algorithms/CGLS.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,36 +19,36 @@
from cil.optimisation.algorithms import Algorithm
import numpy
import logging
import warnings
import warnings

log = logging.getLogger(__name__)


class CGLS(Algorithm):

r'''Conjugate Gradient Least Squares (CGLS) algorithm

The Conjugate Gradient Least Squares (CGLS) algorithm is commonly used for solving large systems of linear equations, due to its fast convergence.

Problem:

.. math::

\min_x || A x - b ||^2_2


Parameters
------------
operator : Operator
Linear operator for the inverse problem
initial : (optional) DataContainer in the domain of the operator, default is a DataContainer filled with zeros.
Initial guess
data : DataContainer in the range of the operator
initial : (optional) DataContainer in the domain of the operator, default is a DataContainer filled with zeros.
Initial guess
data : DataContainer in the range of the operator
Acquired data to reconstruct

Note
-----
Passing tolerance directly to CGLS is being deprecated. Instead we recommend using the callback functionality: https://tomographicimaging.github.io/CIL/nightly/optimisation/#callbacks and in particular the CGLSEarlyStopping callback replicated the old behaviour.
Passing tolerance directly to CGLS is being deprecated. Instead we recommend using the callback functionality: https://tomographicimaging.github.io/CIL/nightly/optimisation/#callbacks and in particular the EarlyStoppingCGLS callback replicated the old behaviour.

Reference
---------
Expand All @@ -57,33 +57,33 @@ class CGLS(Algorithm):
def __init__(self, initial=None, operator=None, data=None, **kwargs):
'''initialisation of the algorithm
'''
#We are deprecating tolerance
self.tolerance=kwargs.pop("tolerance", None)
kwargs = kwargs.copy()
self.tolerance = kwargs.pop("tolerance", None)
if self.tolerance is not None:
warnings.warn( stacklevel=2, category=DeprecationWarning, message="Passing tolerance directly to CGLS is being deprecated. Instead we recommend using the callback functionality: https://tomographicimaging.github.io/CIL/nightly/optimisation/#callbacks and in particular the CGLSEarlyStopping callback replicated the old behaviour")
warnings.warn(stacklevel=2, category=DeprecationWarning, message="Use EarlyStoppingCGLS insead of `tolerance`. See https://tomographicimaging.github.io/CIL/nightly/optimisation/#callbacks")
else:
self.tolerance = 0

super(CGLS, self).__init__(**kwargs)

if initial is None and operator is not None:
initial = operator.domain_geometry().allocate(0)
if initial is not None and operator is not None and data is not None:
self.set_up(initial=initial, operator=operator, data=data)
self.set_up(initial=initial, operator=operator, data=data)

def set_up(self, initial, operator, data):
r'''Initialisation of the algorithm
Parameters
------------
operator : Operator
Linear operator for the inverse problem
initial : (optional) DataContainer in the domain of the operator, default is a DataContainer filled with zeros.
Initial guess
data : DataContainer in the range of the operator
initial : (optional) DataContainer in the domain of the operator, default is a DataContainer filled with zeros.
Initial guess
data : DataContainer in the range of the operator
Acquired data to reconstruct

'''

log.info("%s setting up", self.__class__.__name__)
self.x = initial.copy()
self.operator = operator
Expand Down
192 changes: 130 additions & 62 deletions Wrappers/Python/cil/optimisation/utilities/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,74 @@
import logging
from abc import ABC, abstractmethod
from functools import partialmethod
from pathlib import Path

from tqdm.auto import tqdm as tqdm_auto
from tqdm.std import tqdm as tqdm_std
import numpy as np
from cil.processors import Slicer
from cil.io import TIFFWriter

log = logging.getLogger(__name__)


class Callback(ABC):
'''Base Callback to inherit from for use in :code:`Algorithm.run(callbacks: list[Callback])`.
"""Base Callback to inherit from for use in :code:`Algorithm.run(callbacks: list[Callback])`.

Parameters
----------
verbose: int, choice of 0,1,2, default 1
verbose:
0=quiet, 1=info, 2=debug.
'''
def __init__(self, verbose=1):
interval:
Used by :code:`skip_iteration()`.
"""
def __init__(self, verbose: int = 1, interval: int = 1 << 31):
self.verbose = verbose
self.interval = interval

def skip_iteration(self, algorithm) -> bool:
"""Checks `min(self.interval, algorithm.update_objective_interval)`"""
interval = min(self.interval, algorithm.update_objective_interval)
return interval > 0 and algorithm.iteration % interval != 0 and algorithm.iteration != algorithm.max_iteration

@abstractmethod
def __call__(self, algorithm):
pass


class _OldCallback(Callback):
'''Converts an old-style :code:`def callback` to a new-style :code:`class Callback`.
"""Converts an old-style :code:`def callback` to a new-style :code:`class Callback`.

Parameters
----------
callback: :code:`callable(iteration, objective, x)`
'''
"""
def __init__(self, callback, *args, **kwargs):
super().__init__(*args, **kwargs)
self.func = callback

def __call__(self, algorithm):
if algorithm.update_objective_interval > 0 and algorithm.iteration % algorithm.update_objective_interval == 0:
if not self.skip_iteration(algorithm):
self.func(algorithm.iteration, algorithm.get_last_objective(return_all=self.verbose>=2), algorithm.x)


class ProgressCallback(Callback):
''':code:`tqdm`-based progress bar.
""":code:`tqdm`-based progress bar.

Parameters
----------
tqdm_class: default :code:`tqdm.auto.tqdm`
**tqdm_kwargs:
Passed to :code:`tqdm_class`.
'''
def __init__(self, verbose=1, tqdm_class=tqdm_auto, **tqdm_kwargs):
super().__init__(verbose=verbose)
"""
def __init__(self, verbose: int = 1, interval: int = 1 << 31, tqdm_class=tqdm_auto, **tqdm_kwargs):
super().__init__(verbose=verbose, interval=interval)
self.tqdm_class = tqdm_class
self.tqdm_kwargs = tqdm_kwargs
self._obj_len = 0 # number of objective updates

def __call__(self, algorithm):
if self.skip_iteration(algorithm):
return
if not hasattr(self, 'pbar'):
tqdm_kwargs = self.tqdm_kwargs
tqdm_kwargs.setdefault('total', algorithm.max_iteration)
Expand All @@ -67,16 +82,16 @@ def __call__(self, algorithm):


class _TqdmText(tqdm_std):
''':code:`tqdm`-based progress but text-only updates on separate lines.
""":code:`tqdm`-based progress but text-only updates on separate lines.

Parameters
----------
num_format: str
num_format:
Format spec for postfix numbers (i.e. objective values).
bar_format: str
bar_format:
Passed to :code:`tqdm`.
'''
def __init__(self, *args, num_format='+8.3e', bar_format="{n:>6d}/{total_fmt:<6} {rate_fmt:>9}{postfix}", **kwargs):
"""
def __init__(self, *args, num_format: str='+8.3e', bar_format: str="{n:>6d}/{total_fmt:<6} {rate_fmt:>9}{postfix}", **kwargs):
self.num_format = num_format
super().__init__(*args, bar_format=bar_format, mininterval=0, maxinterval=0, position=0, **kwargs)
self._instances.remove(self) # don't interfere with external progress bars
Expand Down Expand Up @@ -105,85 +120,138 @@ def display(self, *args, **kwargs):


class TextProgressCallback(ProgressCallback):
''':code:`ProgressCallback` but printed on separate lines to screen.
""":code:`ProgressCallback` but printed on separate lines to screen.

Parameters
----------
miniters: int, default :code:`Algorithm.update_objective_interval`
miniters: int, default :code:`min(Algorithm.update_objective_interval, Callback.interval)`
Number of algorithm iterations between screen prints.
'''
"""
__init__ = partialmethod(ProgressCallback.__init__, tqdm_class=_TqdmText)

def __call__(self, algorithm):
if not hasattr(self, 'pbar'):
self.tqdm_kwargs['miniters'] = min((
self.tqdm_kwargs.get('miniters', algorithm.update_objective_interval),
algorithm.update_objective_interval))
algorithm.update_objective_interval,
self.interval))
return super().__call__(algorithm)


class LogfileCallback(TextProgressCallback):
''':code:`TextProgressCallback` but to a file instead of screen.
""":code:`TextProgressCallback` but to a file instead of screen.

Parameters
----------
log_file: FileDescriptorOrPath
Passed to :code:`open()`.
mode: str
mode:
Passed to :code:`open()`.
'''
def __init__(self, log_file, mode='a', **kwargs):
**kwargs:
Passed to :code:`TextProgressCallback`.
"""
def __init__(self, log_file, mode: str='a', **kwargs):
self.fd = open(log_file, mode=mode)
super().__init__(file=self.fd, **kwargs)

class EarlyStoppingObjectiveValue(Callback):
'''Callback that stops iterations if the change in the objective value is less than a provided threshold value.


class EarlyStopping(Callback):
"""Terminates if objective value change < :code:`delta`.

Parameters
----------
threshold: float, default 1e-6
delta:
Usually a small number.
**kwargs:
Passed to :code:`Callback`.

Note
-----
This callback only compares the last two calculated objective values. If `update_objective_interval` is greater than 1, the objective value is not calculated at each iteration (which is the default behaviour), only every `update_objective_interval` iterations.

'''
def __init__(self, threshold=1e-6):
self.threshold=threshold

----
This callback only compares the last two calculated objective values.
If :code:`algorithm.update_objective_interval > 1`, the objective value is not calculated at each iteration.
"""
def __init__(self, delta: float=1e-6, **kwargs):
super().__init__(**kwargs)
self.threshold = delta

def __call__(self, algorithm):
if len(algorithm.loss)>=2:
if np.abs(algorithm.loss[-1]-algorithm.loss[-2])<self.threshold:
raise StopIteration

class CGLSEarlyStopping(Callback):
'''Callback to work with CGLS. It causes the algorithm to terminate if :math:`||A^T(Ax-b)||_2 < \epsilon||A^T(Ax_0-b)||_2` where `epsilon` is set to default as '1e-6', :math:`x` is the current iterate and :math:`x_0` is the initial value.
It will also terminate if the algorithm begins to diverge i.e. if :math:`||x||_2> \omega`, where `omega` is set to default as 1e6.
if not self.skip_iteration(algorithm) and len(loss := algorithm.loss) >= 2 and abs(loss[-1] - loss[-2]) < self.threshold:
raise StopIteration


class EarlyStoppingCGLS(Callback):
r"""Terminates CGLS if :math:`||A^T(Ax-b)||_2 < \epsilon||A^T(Ax_0-b)||_2`, where

- :math:`x` is the current iterate, and
- :math:`x_0` is the initial value.

It will also terminate if the algorithm begins to diverge i.e. if :math:`||x||_2> \omega`.

Parameters
----------
epsilon: float, default 1e-6
Usually a small number: the algorithm to terminate if :math:`||A^T(Ax-b)||_2 < \epsilon||A^T(Ax_0-b)||_2`
omega: float, default 1e6
Usually a large number: the algorithm will terminate if :math:`||x||_2> \omega`

epsilon:
Usually a small number.
omega:
Usually a large number.
**kwargs:
Passed to :code:`Callback`.

Note
-----
This callback is implemented to replicate the automatic behaviour of CGLS in CIL versions <=24. It also replicates the behaviour of https://web.stanford.edu/group/SOL/software/cgls/.
'''
def __init__(self, epsilon=1e-6, omega=1e6):
self.epsilon=epsilon
self.omega=omega


----
This callback is implemented to replicate the automatic behaviour of CGLS in CIL versions <=24.
It also replicates the behaviour of <https://web.stanford.edu/group/SOL/software/cgls/>.
"""
def __init__(self, epsilon: float=1e-6, omega: float=1e6, **kwargs):
super().__init__(**kwargs)
self.epsilon = epsilon
self.omega = omega

def __call__(self, algorithm):

if self.skip_iteration(algorithm):
return
if (algorithm.norms <= algorithm.norms0 * self.epsilon):
print('The norm of the residual is less than {} times the norm of the initial residual and so the algorithm is terminated'.format(self.epsilon))
if self.verbose:
log.info('StopIteration: (residual/initial) norm <= %d', self.epsilon)
raise StopIteration
self.normx = algorithm.x.norm()
if algorithm.normx >= self.omega:
print('The norm of the solution is greater than {} and so the algorithm is terminated'.format(self.omega))
if self.verbose:
log.info('StopIteration: solution norm >= %d', self.omega)
raise StopIteration




class TIFFLogger(Callback):
"""Saves solution as tiff files.

Parameters
----------
directory: FileDescriptorOrPath
Where to save the images.
stem:
The image filename pattern (without filetype suffix).
roi:
The region of interest to slice: `{'axis_name1':(start,stop,step), 'axis_name2':(start,stop,step)}`
- start: int or None: index, default 0.
- stop: int or None: index, default N.
- step: int or None: number of pixels to average together, default 1.
compression:
Passed to :code:`cil.io.TIFFWriter`.
**kwargs:
Passed to :code:`Callback`.
"""
def __init__(self, directory='.', stem='iter_{iteration:04d}', roi: dict|None=None, compression=None, **kwargs):
super().__init__(**kwargs)
self.file_name = f'{Path(directory) / stem}.tif'
self.slicer = Slicer(roi=roi) if roi is not None else None
self.compression = compression

def __call__(self, algorithm):
if self.skip_iteration(algorithm):
return
if self.slicer is None:
data = algorithm.solution
else:
self.slicer.set_input(algorithm.solution)
data = self.slicer.get_output()
w = TIFFWriter(data, file_name=self.file_name.format(iteration=algorithm.iteration), counter_offset=-1, compression=self.compression)
w.write()
Loading
Loading