diff --git a/CHANGELOG.md b/CHANGELOG.md index 0eb818e67e..fc0be4e350 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,8 +28,10 @@ - Make Binner accept accelerated=False (#1887) - Added checks on memory allocations within `FiniteDifferenceLibrary.cpp` and verified the status of the return in `GradientOperator` (#1929) - Build release version of `cilacc.dll` for Windows. Previously was defaulting to the debug build (#1928) + - Armijo step size rule now by default initialises the search for a step size from the previously calculated step size (#1934) - Changes that break backwards compatibility: - CGLS will no longer automatically stop iterations once a default tolerance is reached. The option to pass `tolerance` will be deprecated to be replaced by `optimisation.utilities.callbacks` (#1892) + * 24.1.0 - New Features: @@ -54,6 +56,7 @@ - BlockOperator that would return a BlockDataContainer of shape (1,1) now returns the appropriate DataContainer. BlockDataContainer direct and adjoint methods accept DataContainer as parameter (#1802). - BlurringOperator: remove check for geometry class (old SIRF integration bug) (#1807) - The `ZeroFunction` and `ConstantFunction` now have a Lipschitz constant of 1. (#1768) + - Update dataexample remote data download to work with windows and use zenodo_get for data download (#1774) - Changes that break backwards compatibility: - Merged the files `BlockGeometry.py` and `BlockDataContainer.py` in `framework` to one file `block.py`. Please use `from cil.framework import BlockGeometry, BlockDataContainer` as before (#1799) - Bug fix in `FGP_TV` function to set the default behaviour not to enforce non-negativity (#1826). diff --git a/Wrappers/Python/cil/optimisation/algorithms/FISTA.py b/Wrappers/Python/cil/optimisation/algorithms/FISTA.py index b507e42f04..05143e8157 100644 --- a/Wrappers/Python/cil/optimisation/algorithms/FISTA.py +++ b/Wrappers/Python/cil/optimisation/algorithms/FISTA.py @@ -213,8 +213,19 @@ def update_objective(self): .. math:: f(x) + g(x) """ - self.loss.append(self.f(self.x_old) + self.g(self.x_old)) + self.loss.append(self.calculate_objective_function_at_point(self.x_old)) + def calculate_objective_function_at_point(self, x): + """ Calculates the objective at a given point x + + .. math:: f(x) + g(x) + + Parameters + ---------- + x : DataContainer + + """ + return self.f(x) + self.g(x) class FISTA(ISTA): diff --git a/Wrappers/Python/cil/optimisation/algorithms/GD.py b/Wrappers/Python/cil/optimisation/algorithms/GD.py index 22e1adcc78..9d3fdbee70 100644 --- a/Wrappers/Python/cil/optimisation/algorithms/GD.py +++ b/Wrappers/Python/cil/optimisation/algorithms/GD.py @@ -84,7 +84,7 @@ def set_up(self, initial, objective_function, step_size, preconditioner): log.info("%s setting up", self.__class__.__name__) self.x = initial.copy() - self.objective_function = objective_function + self._objective_function = objective_function if step_size is None: self.step_size_rule = ArmijoStepSizeRule( @@ -106,7 +106,7 @@ def set_up(self, initial, objective_function, step_size, preconditioner): def update(self): '''Performs a single iteration of the gradient descent algorithm''' - self.objective_function.gradient(self.x, out=self.gradient_update) + self._objective_function.gradient(self.x, out=self.gradient_update) if self.preconditioner is not None: self.preconditioner.apply( @@ -117,7 +117,7 @@ def update(self): self.x.sapyb(1.0, self.gradient_update, -step_size, out=self.x) def update_objective(self): - self.loss.append(self.objective_function(self.solution)) + self.loss.append(self._objective_function(self.solution)) def should_stop(self): '''Stopping criterion for the gradient descent algorithm ''' @@ -132,3 +132,20 @@ def step_size(self): else: raise TypeError( "There is not a constant step size, it is set by a step-size rule") + + def calculate_objective_function_at_point(self, x): + """ Calculates the objective at a given point x + + .. math:: f(x) + g(x) + + Parameters + ---------- + x : DataContainer + + """ + return self._objective_function(x) + + @property + def objective_function(self): + warn('The attribute `objective_function` will be deprecated in the future. Please use `calculate_objective_function_at_point` instead.', DeprecationWarning, stacklevel=2) + return self._objective_function \ No newline at end of file diff --git a/Wrappers/Python/cil/optimisation/utilities/StepSizeMethods.py b/Wrappers/Python/cil/optimisation/utilities/StepSizeMethods.py index d7a5c47942..a680bd24d7 100644 --- a/Wrappers/Python/cil/optimisation/utilities/StepSizeMethods.py +++ b/Wrappers/Python/cil/optimisation/utilities/StepSizeMethods.py @@ -19,6 +19,9 @@ from abc import ABC, abstractmethod import numpy from numbers import Number +import logging + +log = logging.getLogger(__name__) class StepSizeRule(ABC): """ @@ -82,6 +85,9 @@ class ArmijoStepSizeRule(StepSizeRule): The amount the step_size is reduced if the criterion is not met max_iterations: integer, optional, default is numpy.ceil (2 * numpy.log10(alpha) / numpy.log10(2)) The maximum number of iterations to find a suitable step size + warmstart: Boolean, default is True + If `warmstart = True` the initial step size at each Armijo iteration is the calculated step size from the last iteration. If `warmstart = False` at each Armijo iteration, the initial step size is reset to the original, large `alpha`. + In the case of *well-behaved* convex functions, `warmstart = True` is likely to be computationally less expensive. In the case of non-convex functions, or particularly tricky functions, setting `warmstart = False` may be beneficial. Reference ------------ @@ -91,14 +97,14 @@ class ArmijoStepSizeRule(StepSizeRule): """ - def __init__(self, alpha=1e6, beta=0.5, max_iterations=None): + def __init__(self, alpha=1e6, beta=0.5, max_iterations=None, warmstart=True): '''Initialises the step size rule ''' self.alpha_orig = alpha if self.alpha_orig is None: # Can be removed when alpha and beta are deprecated in GD self.alpha_orig = 1e6 - + self.alpha = self.alpha_orig self.beta = beta if self.beta is None: # Can be removed when alpha and beta are deprecated in GD self.beta = 0.5 @@ -106,6 +112,8 @@ def __init__(self, alpha=1e6, beta=0.5, max_iterations=None): self.max_iterations = max_iterations if self.max_iterations is None: self.max_iterations = numpy.ceil(2 * numpy.log10(self.alpha_orig) / numpy.log10(2)) + + self.warmstart=warmstart def get_step_size(self, algorithm): """ @@ -117,26 +125,33 @@ def get_step_size(self, algorithm): """ k = 0 - self.alpha = self.alpha_orig - f_x = algorithm.objective_function(algorithm.solution) + if not self.warmstart: + self.alpha = self.alpha_orig + + f_x = algorithm.calculate_objective_function_at_point(algorithm.solution) self.x_armijo = algorithm.solution.copy() - + + log.debug("Starting Armijo backtracking with initial step size: %f", self.alpha) + while k < self.max_iterations: algorithm.gradient_update.multiply(self.alpha, out=self.x_armijo) algorithm.solution.subtract(self.x_armijo, out=self.x_armijo) - f_x_a = algorithm.objective_function(self.x_armijo) + f_x_a = algorithm.calculate_objective_function_at_point(self.x_armijo) sqnorm = algorithm.gradient_update.squared_norm() if f_x_a - f_x <= - (self.alpha/2.) * sqnorm: break k += 1. self.alpha *= self.beta + + log.info("Armijo rule took %d iterations to find step size", k) if k == self.max_iterations: raise ValueError( 'Could not find a proper step_size in {} loops. Consider increasing alpha or max_iterations.'.format(self.max_iterations)) + return self.alpha diff --git a/Wrappers/Python/cil/utilities/dataexample.py b/Wrappers/Python/cil/utilities/dataexample.py index cd0fb93de3..1acdf66411 100644 --- a/Wrappers/Python/cil/utilities/dataexample.py +++ b/Wrappers/Python/cil/utilities/dataexample.py @@ -25,10 +25,9 @@ import os.path import sys from zipfile import ZipFile -from urllib.request import urlopen -from io import BytesIO from scipy.io import loadmat from cil.io import NEXUSDataReader, NikonDataReader, ZEISSDataReader +from zenodo_get import zenodo_get class DATA(object): @classmethod @@ -46,21 +45,15 @@ def get(cls, size=None, scale=(0,1), **kwargs): class REMOTEDATA(DATA): FOLDER = '' - URL = '' - FILE_SIZE = '' + ZENODO_RECORD = '' + ZIP_FILE = '' @classmethod def get(cls, data_dir): return None @classmethod - def _download_and_extract_from_url(cls, data_dir): - with urlopen(cls.URL) as response: - with BytesIO(response.read()) as bytes, ZipFile(bytes) as zipfile: - zipfile.extractall(path = data_dir) - - @classmethod - def download_data(cls, data_dir): + def download_data(cls, data_dir, prompt=True): ''' Download a dataset from a remote repository @@ -71,14 +64,18 @@ def download_data(cls, data_dir): ''' if os.path.isdir(os.path.join(data_dir, cls.FOLDER)): - print("Dataset already exists in " + data_dir) + print("Dataset folder already exists in " + data_dir) else: - if input("Are you sure you want to download " + cls.FILE_SIZE + " dataset from " + cls.URL + " ? (y/n)") == "y": - print('Downloading dataset from ' + cls.URL) - cls._download_and_extract_from_url(os.path.join(data_dir,cls.FOLDER)) - print('Download complete') - else: + user_input = input("Are you sure you want to download {cls.ZIP_FILE} dataset from Zenodo record {cls.ZENODO_RECORD}? [Y/n]: ") if prompt else 'y' + if user_input.lower() not in ('y', 'yes'): print('Download cancelled') + return False + + zenodo_get([cls.ZENODO_RECORD, '-g', cls.ZIP_FILE, '-o', data_dir]) + with ZipFile(os.path.join(data_dir, cls.ZIP_FILE), 'r') as zip_ref: + zip_ref.extractall(os.path.join(data_dir, cls.FOLDER)) + os.remove(os.path.join(data_dir, cls.ZIP_FILE)) + return True class BOAT(CILDATA): @classmethod @@ -195,15 +192,21 @@ def get(cls, **kwargs): class WALNUT(REMOTEDATA): ''' A microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516 + + Example + -------- + >>> data_dir = 'my_PC/data_folder' + >>> dataexample.WALNUT.download_data(data_dir) # download the data + >>> dataexample.WALNUT.get(data_dir) # load the data ''' FOLDER = 'walnut' - URL = 'https://zenodo.org/record/4822516/files/walnut.zip' - FILE_SIZE = '6.4 GB' + ZENODO_RECORD = '4822516' + ZIP_FILE = 'walnut.zip' @classmethod def get(cls, data_dir): ''' - A microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516 + Get the microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516 This function returns the raw projection data from the .txrm file Parameters @@ -227,15 +230,21 @@ def get(cls, data_dir): class USB(REMOTEDATA): ''' A microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516 + + Example + -------- + >>> data_dir = 'my_PC/data_folder' + >>> dataexample.USB.download_data(data_dir) # download the data + >>> dataexample.USB.get(data_dir) # load the data ''' FOLDER = 'USB' - URL = 'https://zenodo.org/record/4822516/files/usb.zip' - FILE_SIZE = '3.2 GB' + ZENODO_RECORD = '4822516' + ZIP_FILE = 'usb.zip' @classmethod def get(cls, data_dir): ''' - A microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516 + Get the microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516 This function returns the raw projection data from the .txrm file Parameters @@ -259,15 +268,21 @@ def get(cls, data_dir): class KORN(REMOTEDATA): ''' A microcomputed tomography dataset of a sunflower seeds in a box from https://zenodo.org/records/6874123 + + Example + -------- + >>> data_dir = 'my_PC/data_folder' + >>> dataexample.KORN.download_data(data_dir) # download the data + >>> dataexample.KORN.get(data_dir) # load the data ''' FOLDER = 'korn' - URL = 'https://zenodo.org/record/6874123/files/korn.zip' - FILE_SIZE = '2.9 GB' + ZENODO_RECORD = '6874123' + ZIP_FILE = 'korn.zip' @classmethod def get(cls, data_dir): ''' - A microcomputed tomography dataset of a sunflower seeds in a box from https://zenodo.org/records/6874123 + Get the microcomputed tomography dataset of a sunflower seeds in a box from https://zenodo.org/records/6874123 This function returns the raw projection data from the .xtekct file Parameters @@ -279,6 +294,7 @@ def get(cls, data_dir): ------- ImageData The korn dataset + ''' filepath = os.path.join(data_dir, cls.FOLDER, 'Korn i kasse','47209 testscan korn01_recon.xtekct') try: @@ -293,10 +309,40 @@ class SANDSTONE(REMOTEDATA): ''' A synchrotron x-ray tomography dataset of sandstone from https://zenodo.org/records/4912435 A small subset of the data containing selected projections and 4 slices of the reconstruction + + Example + -------- + >>> data_dir = 'my_PC/data_folder' + >>> dataexample.SANDSTONE.download_data(data_dir) # download the data + >>> dataexample.SANDSTONE.get(data_dir) # load the data ''' FOLDER = 'sandstone' - URL = 'https://zenodo.org/records/4912435/files/small.zip' - FILE_SIZE = '227 MB' + ZENODO_RECORD = '4912435' + ZIP_FILE = 'small.zip' + + @classmethod + def get(cls, data_dir, filename): + ''' + Get the synchrotron x-ray tomography dataset of sandstone from https://zenodo.org/records/4912435 + A small subset of the data containing selected projections and 4 slices of the reconstruction + Parameters + ---------- + data_dir: str + The path to the directory where the dataset is stored. Data can be downloaded with dataexample.SANDSTONE.download_data(data_dir) + + file: str + The slices or projections to return, specify the path to the file within the data_dir + + Returns + ------- + ImageData + The selected sandstone dataset + ''' + extension = os.path.splitext(filename)[1] + if extension == '.mat': + return loadmat(os.path.join(data_dir,filename)) + raise KeyError(f"Unknown extension: {extension}") + class TestData(object): '''Class to return test data diff --git a/Wrappers/Python/test/test_algorithms.py b/Wrappers/Python/test/test_algorithms.py index e3f6302d78..845b5a44ce 100644 --- a/Wrappers/Python/test/test_algorithms.py +++ b/Wrappers/Python/test/test_algorithms.py @@ -105,13 +105,10 @@ def test_GD(self): alg = GD(initial=initial, objective_function=norm2sq, step_size=step_size, atol=1e-9, rtol=1e-6) - alg.max_iteration = 1000 alg.run(1000,verbose=0) self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array()) alg = GD(initial=initial, objective_function=norm2sq, step_size=step_size, - atol=1e-9, rtol=1e-6, max_iteration=20, update_objective_interval=2) - alg.max_iteration = 20 - self.assertTrue(alg.max_iteration == 20) + atol=1e-9, rtol=1e-6, update_objective_interval=2) self.assertTrue(alg.update_objective_interval == 2) alg.run(20, verbose=0) self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array()) @@ -132,7 +129,6 @@ def test_update_interval_0(self): norm2sq = LeastSquares(identity, b) alg = GD(initial=initial, objective_function=norm2sq, - max_iteration=20, update_objective_interval=0, atol=1e-9, rtol=1e-6) self.assertTrue(alg.update_objective_interval == 0) @@ -176,7 +172,6 @@ def test_gd_constant_step_size_init(self): def test_gd_fixed_step_size_rosen(self): gd = GD(initial=self.initial, objective_function=self.f, step_size=0.002, - max_iteration=3000, update_objective_interval=500) gd.run(3000, verbose=0) np.testing.assert_allclose( @@ -224,23 +219,18 @@ def test_GDArmijo(self): norm2sq = LeastSquares(identity, b) alg = GD(initial=initial, objective_function=norm2sq) - alg.max_iteration = 100 alg.run(100, verbose=0) self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array()) - alg = GD(initial=initial, objective_function=norm2sq, - max_iteration=20, update_objective_interval=2) - # alg.max_iteration = 20 - self.assertTrue(alg.max_iteration == 20) + alg = GD(initial=initial, objective_function=norm2sq, update_objective_interval=2) self.assertTrue(alg.update_objective_interval==2) alg.run(20, verbose=0) self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array()) def test_gd_armijo_rosen(self): - armj = ArmijoStepSizeRule(alpha=50, max_iterations=150) + armj = ArmijoStepSizeRule(alpha=50, max_iterations=50, warmstart=False) gd = GD(initial=self.initial, objective_function=self.f, step_size=armj, - max_iteration=2500, update_objective_interval=500) - gd.run(2500,verbose=0) + gd.run(3500,verbose=0) np.testing.assert_allclose( gd.solution.array[0], self.scipy_opt_high.x[0], atol=1e-2) np.testing.assert_allclose( @@ -262,31 +252,26 @@ def test_FISTA(self): log.info("initial objective %s", norm2sq(initial)) alg = FISTA(initial=initial, f=norm2sq, g=ZeroFunction()) - alg.max_iteration = 2 alg.run(20, verbose=0) self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array()) alg = FISTA(initial=initial, f=norm2sq, g=ZeroFunction(), - max_iteration=2, update_objective_interval=2) + update_objective_interval=2) - self.assertTrue(alg.max_iteration == 2) self.assertTrue(alg.update_objective_interval == 2) alg.run(20, verbose=0) self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array()) # Testing g=None - alg = FISTA(initial=initial, f=norm2sq, g=None, - max_iteration=2, update_objective_interval=2) - self.assertTrue(alg.max_iteration == 2) + alg = FISTA(initial=initial, f=norm2sq, g=None, update_objective_interval=2) self.assertTrue(alg.update_objective_interval == 2) alg.run(20, verbose=0) self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array()) # Testing f=None alg = FISTA(initial=initial, f=None, g=L1Norm(b=b), - max_iteration=2, update_objective_interval=2) - self.assertTrue(alg.max_iteration == 2) + update_objective_interval=2) self.assertTrue(alg.update_objective_interval == 2) alg.run(20, verbose=0) self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array()) @@ -294,7 +279,7 @@ def test_FISTA(self): # Testing f and g is None with self.assertRaises(ValueError): alg = FISTA(initial=initial, f=None, g=None, - max_iteration=2, update_objective_interval=2) + update_objective_interval=2) def test_FISTA_update(self): @@ -319,7 +304,7 @@ def test_FISTA_update(self): # ista run 10 iteration tmp_initial = ig.allocate() - fista = FISTA(initial=tmp_initial, f=f, g=g, max_iteration=1) + fista = FISTA(initial=tmp_initial, f=f, g=g) fista.run(1) # fista update method @@ -348,11 +333,11 @@ def test_FISTA_update(self): self.assertTrue(res1 == res2) tmp_initial = ig.allocate() - fista1 = FISTA(initial=tmp_initial, f=f, g=g, max_iteration=1) + fista1 = FISTA(initial=tmp_initial, f=f, g=g) self.assertTrue(fista1.is_provably_convergent()) fista1 = FISTA(initial=tmp_initial, f=f, g=g, - max_iteration=1, step_size=30.0) + step_size=30.0) self.assertFalse(fista1.is_provably_convergent()) def test_FISTA_Norm2Sq(self): @@ -367,13 +352,11 @@ def test_FISTA_Norm2Sq(self): opt = {'tol': 1e-4, 'memopt': False} log.info("initial objective %s", norm2sq(initial)) alg = FISTA(initial=initial, f=norm2sq, g=ZeroFunction()) - alg.max_iteration = 2 alg.run(20, verbose=0) self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array()) alg = FISTA(initial=initial, f=norm2sq, g=ZeroFunction(), - max_iteration=2, update_objective_interval=3) - self.assertTrue(alg.max_iteration == 2) + update_objective_interval=3) self.assertTrue(alg.update_objective_interval == 3) alg.run(20, verbose=0) @@ -419,7 +402,6 @@ def test_FISTA_Denoising(self): initial = ig.allocate() fista = FISTA(initial=initial, f=reg, g=fid) - fista.max_iteration = 3000 fista.update_objective_interval = 500 fista.run(3000, verbose=0) rmse = (fista.get_output() - data).norm() / data.as_array().size @@ -483,7 +465,7 @@ def test_update(self): # ista run 10 iteration tmp_initial = self.ig.allocate() - ista = ISTA(initial=tmp_initial, f=self.f, g=self.g, max_iteration=1) + ista = ISTA(initial=tmp_initial, f=self.f, g=self.g) ista.run(1) x = tmp_initial.copy() @@ -505,7 +487,7 @@ def test_update_g_none(self): # ista run 10 iteration tmp_initial = self.ig.allocate() - ista = ISTA(initial=tmp_initial, f=self.f, g=None, max_iteration=1) + ista = ISTA(initial=tmp_initial, f=self.f, g=None) ista.run(1) x = tmp_initial.copy() @@ -526,7 +508,7 @@ def test_update_f_none(self): # ista run 1 iteration tmp_initial = self.ig.allocate() - ista = ISTA(initial=tmp_initial, f=None, g=self.h, max_iteration=1) + ista = ISTA(initial=tmp_initial, f=None, g=self.h) ista.run(1) x = tmp_initial.copy() @@ -546,23 +528,23 @@ def test_update_f_none(self): def test_f_and_g_none(self): tmp_initial = self.ig.allocate() with self.assertRaises(ValueError): - ista = ISTA(initial=tmp_initial, f=None, g=None, max_iteration=1) + ista = ISTA(initial=tmp_initial, f=None, g=None) def test_provable_condition(self): tmp_initial = self.ig.allocate() - ista1 = ISTA(initial=tmp_initial, f=self.f, g=self.g, max_iteration=10) + ista1 = ISTA(initial=tmp_initial, f=self.f, g=self.g) self.assertTrue(ista1.is_provably_convergent()) ista1 = ISTA(initial=tmp_initial, f=self.f, g=self.g, - max_iteration=10, step_size=30.0) + step_size=30.0) self.assertFalse(ista1.is_provably_convergent()) @unittest.skipUnless(has_cvxpy, "CVXpy not installed") def test_with_cvxpy(self): ista = ISTA(initial=self.initial, f=self.f, - g=self.g, max_iteration=2000) + g=self.g) ista.run(2000, verbose=0) u_cvxpy = cvxpy.Variable(self.ig.shape[0]) @@ -738,7 +720,6 @@ def setup(data, dnoise): # Setup and run the PDHG algorithm pdhg1 = PDHG(f=f1, g=g, operator=operator, tau=tau, sigma=sigma) - pdhg1.max_iteration = 2000 pdhg1.update_objective_interval = 200 pdhg1.run(1000, verbose=0) @@ -763,7 +744,7 @@ def setup(data, dnoise): # Setup and run the PDHG algorithm pdhg1 = PDHG(f=f1, g=g, operator=operator, tau=tau, sigma=sigma, - max_iteration=2000, update_objective_interval=200) + update_objective_interval=200) pdhg1.run(1000, verbose=0) @@ -788,7 +769,6 @@ def setup(data, dnoise): # Setup and run the PDHG algorithm pdhg1 = PDHG(f=f1, g=g, operator=operator, tau=tau, sigma=sigma) - pdhg1.max_iteration = 2000 pdhg1.update_objective_interval = 200 pdhg1.run(1000, verbose=0) @@ -805,28 +785,28 @@ def test_PDHG_step_sizes(self): operator = 3*IdentityOperator(ig) # check if sigma, tau are None - pdhg = PDHG(f=f, g=g, operator=operator, max_iteration=10) + pdhg = PDHG(f=f, g=g, operator=operator) self.assertAlmostEqual(pdhg.sigma, 1./operator.norm()) self.assertAlmostEqual(pdhg.tau, 1./operator.norm()) # check if sigma is negative with self.assertRaises(ValueError): pdhg = PDHG(f=f, g=g, operator=operator, - max_iteration=10, sigma=-1) + sigma=-1) # check if tau is negative with self.assertRaises(ValueError): - pdhg = PDHG(f=f, g=g, operator=operator, max_iteration=10, tau=-1) + pdhg = PDHG(f=f, g=g, operator=operator,tau=-1) # check if tau is None sigma = 3.0 - pdhg = PDHG(f=f, g=g, operator=operator, sigma=sigma, max_iteration=10) + pdhg = PDHG(f=f, g=g, operator=operator, sigma=sigma) self.assertAlmostEqual(pdhg.sigma, sigma) self.assertAlmostEqual(pdhg.tau, 1./(sigma * operator.norm()**2)) # check if sigma is None tau = 3.0 - pdhg = PDHG(f=f, g=g, operator=operator, tau=tau, max_iteration=10) + pdhg = PDHG(f=f, g=g, operator=operator, tau=tau) self.assertAlmostEqual(pdhg.tau, tau) self.assertAlmostEqual(pdhg.sigma, 1./(tau * operator.norm()**2)) @@ -834,7 +814,7 @@ def test_PDHG_step_sizes(self): tau = 1.0 sigma = 1.0 pdhg = PDHG(f=f, g=g, operator=operator, tau=tau, - sigma=sigma, max_iteration=10) + sigma=sigma) self.assertAlmostEqual(pdhg.tau, tau) self.assertAlmostEqual(pdhg.sigma, sigma) @@ -843,29 +823,29 @@ def test_PDHG_step_sizes(self): sigma = ig1.allocate() with self.assertRaises(ValueError): pdhg = PDHG(f=f, g=g, operator=operator, - sigma=sigma, max_iteration=10) + sigma=sigma) # check sigma/tau as arrays, tau wrong shape tau = ig1.allocate() with self.assertRaises(ValueError): - pdhg = PDHG(f=f, g=g, operator=operator, tau=tau, max_iteration=10) + pdhg = PDHG(f=f, g=g, operator=operator, tau=tau) # check sigma not Number or object with correct shape with self.assertRaises(AttributeError): pdhg = PDHG(f=f, g=g, operator=operator, - sigma="sigma", max_iteration=10) + sigma="sigma") # check tau not Number or object with correct shape with self.assertRaises(AttributeError): pdhg = PDHG(f=f, g=g, operator=operator, - tau="tau", max_iteration=10) + tau="tau") # check warning message if condition is not satisfied sigma = 4 tau = 1/3 with self.assertWarnsRegex(UserWarning, "Convergence criterion"): pdhg = PDHG(f=f, g=g, operator=operator, tau=tau, - sigma=sigma, max_iteration=10) + sigma=sigma) def test_PDHG_strongly_convex_gamma_g(self): ig = ImageGeometry(3, 3) @@ -880,7 +860,7 @@ def test_PDHG_strongly_convex_gamma_g(self): tau = 1.0 pdhg = PDHG(f=f, g=g, operator=operator, sigma=sigma, tau=tau, - max_iteration=5, gamma_g=0.5) + gamma_g=0.5) pdhg.run(1, verbose=0) self.assertAlmostEqual( pdhg.theta, 1.0 / np.sqrt(1 + 2 * pdhg.gamma_g * tau)) @@ -893,12 +873,12 @@ def test_PDHG_strongly_convex_gamma_g(self): # check negative strongly convex constant with self.assertRaises(ValueError): pdhg = PDHG(f=f, g=g, operator=operator, sigma=sigma, tau=tau, - max_iteration=5, gamma_g=-0.5) + gamma_g=-0.5) # check strongly convex constant not a number with self.assertRaises(ValueError): pdhg = PDHG(f=f, g=g, operator=operator, sigma=sigma, tau=tau, - max_iteration=5, gamma_g="-0.5") + gamma_g="-0.5") def test_PDHG_strongly_convex_gamma_fcong(self): ig = ImageGeometry(3, 3) @@ -913,7 +893,7 @@ def test_PDHG_strongly_convex_gamma_fcong(self): tau = 1.0 pdhg = PDHG(f=f, g=g, operator=operator, sigma=sigma, tau=tau, - max_iteration=5, gamma_fconj=0.5) + gamma_fconj=0.5) pdhg.run(1, verbose=0) self.assertEqual(pdhg.theta, 1.0 / np.sqrt(1 + 2 * pdhg.gamma_fconj * sigma)) @@ -926,14 +906,14 @@ def test_PDHG_strongly_convex_gamma_fcong(self): # check negative strongly convex constant try: pdhg = PDHG(f=f, g=g, operator=operator, sigma=sigma, tau=tau, - max_iteration=5, gamma_fconj=-0.5) + gamma_fconj=-0.5) except ValueError as ve: log.info(str(ve)) # check strongly convex constant not a number try: pdhg = PDHG(f=f, g=g, operator=operator, sigma=sigma, tau=tau, - max_iteration=5, gamma_fconj="-0.5") + gamma_fconj="-0.5") except ValueError as ve: log.info(str(ve)) @@ -947,7 +927,7 @@ def test_PDHG_strongly_convex_both_fconj_and_g(self): operator = IdentityOperator(ig) try: - pdhg = PDHG(f=f, g=g, operator=operator, max_iteration=10, + pdhg = PDHG(f=f, g=g, operator=operator, gamma_g=0.5, gamma_fconj=0.5) pdhg.run(verbose=0) except ValueError as err: @@ -1031,7 +1011,7 @@ def test_update(self): # sirt run 5 iterations tmp_initial = self.ig.allocate() sirt = SIRT(initial=tmp_initial, operator=self.Aop, - data=self.bop, max_iteration=5) + data=self.bop) sirt.run(5) x = tmp_initial.copy() @@ -1046,22 +1026,22 @@ def test_update(self): def test_update_constraints(self): alg = SIRT(initial=self.initial2, operator=self.A2, - data=self.b2, max_iteration=20) + data=self.b2) alg.run(20,verbose=0) np.testing.assert_array_almost_equal(alg.x.array, self.b2.array) alg = SIRT(initial=self.initial2, operator=self.A2, - data=self.b2, max_iteration=20, upper=0.3) + data=self.b2, upper=0.3) alg.run(20,verbose=0) np.testing.assert_almost_equal(alg.solution.max(), 0.3) alg = SIRT(initial=self.initial2, operator=self.A2, - data=self.b2, max_iteration=20, lower=0.7) + data=self.b2, lower=0.7) alg.run(20,verbose=0) np.testing.assert_almost_equal(alg.solution.min(), 0.7) alg = SIRT(initial=self.initial2, operator=self.A2, data=self.b2, - max_iteration=20, constraint=IndicatorBox(lower=0.1, upper=0.3)) + constraint=IndicatorBox(lower=0.1, upper=0.3)) alg.run(20,verbose=0) np.testing.assert_almost_equal(alg.solution.max(), 0.3) np.testing.assert_almost_equal(alg.solution.min(), 0.1) @@ -1069,7 +1049,7 @@ def test_update_constraints(self): def test_SIRT_relaxation_parameter(self): tmp_initial = self.ig.allocate() alg = SIRT(initial=tmp_initial, operator=self.Aop, - data=self.bop, max_iteration=5) + data=self.bop) with self.assertRaises(ValueError): alg.set_relaxation_parameter(0) @@ -1078,7 +1058,7 @@ def test_SIRT_relaxation_parameter(self): alg.set_relaxation_parameter(2) alg = SIRT(initial=self.initial2, operator=self.A2, - data=self.b2, max_iteration=20) + data=self.b2) alg.set_relaxation_parameter(0.5) self.assertEqual(alg.relaxation_parameter, 0.5) @@ -1095,7 +1075,7 @@ def test_SIRT_nan_inf_values(self): tmp_initial = self.ig.allocate() sirt = SIRT(initial=tmp_initial, operator=Aop_nan_inf, - data=self.bop, max_iteration=5) + data=self.bop) self.assertFalse(np.any(sirt.M == np.inf)) self.assertFalse(np.any(sirt.D == np.inf)) @@ -1117,7 +1097,7 @@ def test_SIRT_remove_nan_or_inf_with_BlockDataContainer(self): tmp_initial = ig.allocate() sirt = SIRT(initial=tmp_initial, operator=Aop, - data=bop, max_iteration=5) + data=bop) for el in sirt.M.containers: self.assertFalse(np.any(el == np.inf)) @@ -1127,13 +1107,13 @@ def test_SIRT_with_TV(self): data = dataexample.SIMPLE_PHANTOM_2D.get(size=(128, 128)) ig = data.geometry A = IdentityOperator(ig) - constraint = TotalVariation(warm_start=False, max_iteration=100) + constraint = TotalVariation(warm_start=False) initial = ig.allocate('random', seed=5) sirt = SIRT(initial=initial, operator=A, data=data, - max_iteration=2, constraint=constraint) + constraint=constraint) sirt.run(2, verbose=0) f = LeastSquares(A, data, c=0.5) - fista = FISTA(initial=initial, f=f, g=constraint, max_iteration=1000) + fista = FISTA(initial=initial, f=f, g=constraint) fista.run(100, verbose=0) self.assertNumpyArrayAlmostEqual(fista.x.as_array(), sirt.x.as_array()) @@ -1144,7 +1124,7 @@ def test_SIRT_with_TV_warm_start(self): constraint = 1e6*TotalVariation(warm_start=True, max_iteration=100) initial = ig.allocate('random', seed=5) sirt = SIRT(initial=initial, operator=A, data=data, - max_iteration=150, constraint=constraint) + constraint=constraint) sirt.run(25, verbose=0) self.assertNumpyArrayAlmostEqual( @@ -1470,11 +1450,11 @@ def do_test_with_fidelity(self, fidelity): F = self.F admm = LADMM(f=G, g=F, operator=K, tau=self.tau, sigma=self.sigma, - max_iteration=100, update_objective_interval=10) + update_objective_interval=10) admm.run(1, verbose=0) admm_noaxpby = LADMM(f=G, g=F, operator=K, tau=self.tau, sigma=self.sigma, - max_iteration=100, update_objective_interval=10) + update_objective_interval=10) admm_noaxpby.run(1, verbose=0) np.testing.assert_array_almost_equal( admm.solution.as_array(), admm_noaxpby.solution.as_array()) @@ -1506,14 +1486,14 @@ def test_compare_with_PDHG(self): tau = 1./normK pdhg = PDHG(f=F, g=G, operator=K, tau=tau, sigma=sigma, - max_iteration=500, update_objective_interval=10) + update_objective_interval=10) pdhg.run(500,verbose=0) sigma = 1 tau = sigma/normK**2 admm = LADMM(f=G, g=F, operator=K, tau=tau, sigma=sigma, - max_iteration=500, update_objective_interval=10) + update_objective_interval=10) admm.run(500,verbose=0) np.testing.assert_almost_equal( admm.solution.array, pdhg.solution.array, decimal=3) @@ -1564,8 +1544,7 @@ def test_PD3O_PDHG_denoising_1_iteration(self): G = 0.5 * L2NormSquared(b=self.data) sigma = 1./norm_op tau = 1./norm_op - pdhg = PDHG(f=F, g=G, operator=operator, tau=tau, sigma=sigma, update_objective_interval = 100, - max_iteration = 2000) + pdhg = PDHG(f=F, g=G, operator=operator, tau=tau, sigma=sigma, update_objective_interval = 100) pdhg.run(1) # setup PD3O denoising (F=ZeroFunction) @@ -1576,8 +1555,7 @@ def test_PD3O_PDHG_denoising_1_iteration(self): delta = 1./norm_op pd3O = PD3O(f=F, g=G, h=H, operator=operator, gamma=gamma, delta=delta, - update_objective_interval = 100, - max_iteration = 2000) + update_objective_interval = 100) pd3O.run(1) # PD3O vs pdhg diff --git a/Wrappers/Python/test/test_dataexample.py b/Wrappers/Python/test/test_dataexample.py index b1b2acaa52..4ace99a58f 100644 --- a/Wrappers/Python/test/test_dataexample.py +++ b/Wrappers/Python/test/test_dataexample.py @@ -25,11 +25,11 @@ from testclass import CCPiTestClass import platform import numpy as np -from unittest.mock import patch, MagicMock -from urllib import request +from unittest.mock import patch from zipfile import ZipFile from io import StringIO -from tempfile import NamedTemporaryFile +import uuid +from zenodo_get import zenodo_get initialise_tests() @@ -157,116 +157,89 @@ def test_load_SIMULATED_CONE_BEAM_DATA(self): class TestRemoteData(unittest.TestCase): def setUp(self): - self.data_list = ['WALNUT','USB','KORN','SANDSTONE'] - self.shapes_path = os.path.join(dataexample.CILDATA.data_dir, dataexample.TestData.SHAPES) - def mock_urlopen(self, mock_urlopen, zipped_bytes): - mock_response = MagicMock() - mock_response.read.return_value = zipped_bytes - mock_response.__enter__.return_value = mock_response - mock_urlopen.return_value = mock_response - @unittest.skipIf(platform.system() == 'Windows', "Skip on Windows") - @patch('cil.utilities.dataexample.urlopen') - def test_unzip_remote_data(self, mock_urlopen): - ''' - Test the _download_and_extract_data_from_url function correctly extracts files from a byte string - The zipped byte string is mocked using a temporary local zip file + def mock_zenodo_get(*args): + # mock zenodo_get by making a zip file containing the shapes test data when the function is called + shapes_path = os.path.join(dataexample.CILDATA.data_dir, dataexample.TestData.SHAPES) + with ZipFile(os.path.join(args[0][4], args[0][2]), mode='w') as zip_file: + zip_file.write(shapes_path, arcname=dataexample.TestData.SHAPES) + + + @patch('cil.utilities.dataexample.input', return_value='y') + @patch('cil.utilities.dataexample.zenodo_get', side_effect=mock_zenodo_get) + def test_download_data_input_y(self, mock_zenodo_get, input): ''' - - # create a temporary zip file to test the function - with NamedTemporaryFile(suffix = '.zip') as tf: - tmp_path = os.path.dirname(tf.name) - tmp_dir = os.path.splitext(os.path.basename(tf.name))[0] - with ZipFile(tf.name, mode='w') as zip_file: - zip_file.write(self.shapes_path, arcname=dataexample.TestData.SHAPES) - - with open(tf.name, 'rb') as zip_file: - zipped_bytes = zip_file.read() - - self.mock_urlopen(mock_urlopen, zipped_bytes) - dataexample.REMOTEDATA._download_and_extract_from_url(os.path.join(tmp_path, tmp_dir)) + Test the download_data function, when the user input is 'y' to 'are you sure you want to download data' + The user input to confirm the download is mocked as 'y' + The zip file download is mocked by creating a zip file locally + Test the download_data function correctly extracts files from the zip file + ''' + # create a temporary folder in the CIL data directory + tmp_dir = os.path.join(dataexample.CILDATA.data_dir, str(uuid.uuid4())) + os.makedirs(tmp_dir) + # redirect print output + capturedOutput = StringIO() + sys.stdout = capturedOutput + for data in self.data_list: + test_func = getattr(dataexample, data) + test_func.download_data(tmp_dir) + # Test the data file exists + self.assertTrue(os.path.isfile(os.path.join(tmp_dir, getattr(test_func, 'FOLDER'), dataexample.TestData.SHAPES)), + msg = "Download data test failed with dataset " + data) + # Test the zip file is removed + self.assertFalse(os.path.isfile(os.path.join(tmp_dir, getattr(test_func, 'ZIP_FILE')))) + # return to standard print output + sys.stdout = sys.__stdout__ + shutil.rmtree(tmp_dir) - self.assertTrue(os.path.isfile(os.path.join(tmp_path, tmp_dir, dataexample.TestData.SHAPES))) - if os.path.exists(os.path.join(tmp_path,tmp_dir)): - shutil.rmtree(os.path.join(tmp_path,tmp_dir)) - - @unittest.skipIf(platform.system() == 'Windows', "Skip on Windows") - @patch('cil.utilities.dataexample.input', return_value='n') - @patch('cil.utilities.dataexample.urlopen') - def test_download_data_input_n(self, mock_urlopen, input): + @patch('cil.utilities.dataexample.input', return_value='n') + @patch('cil.utilities.dataexample.zenodo_get', side_effect=mock_zenodo_get) + def test_download_data_input_n(self, mock_zenodo_get, input): ''' Test the download_data function, when the user input is 'n' to 'are you sure you want to download data' - The zipped byte string is mocked using a temporary local zip file ''' - - # create a temporary zip file to test the function - with NamedTemporaryFile(suffix = '.zip') as tf: - tmp_path = os.path.dirname(tf.name) - tmp_dir = os.path.splitext(os.path.basename(tf.name))[0] - with ZipFile(tf.name, mode='w') as zip_file: - zip_file.write(self.shapes_path, arcname=dataexample.TestData.SHAPES) - - with open(tf.name, 'rb') as zip_file: - zipped_bytes = zip_file.read() - - self.mock_urlopen(mock_urlopen, zipped_bytes) - + # create a temporary folder in the CIL data directory + tmp_dir = os.path.join(dataexample.CILDATA.data_dir, str(uuid.uuid4())) + os.makedirs(tmp_dir) for data in self.data_list: # redirect print output - capturedOutput = StringIO() - sys.stdout = capturedOutput + capturedOutput = StringIO() + sys.stdout = capturedOutput test_func = getattr(dataexample, data) - test_func.download_data(os.path.join(tmp_path, tmp_dir)) - self.assertFalse(os.path.isfile(os.path.join(tmp_path, tmp_dir, test_func.FOLDER, dataexample.TestData.SHAPES)), msg = "Failed with dataset " + data) - self.assertEqual(capturedOutput.getvalue(),'Download cancelled\n', msg = "Failed with dataset " + data) + test_func.download_data(tmp_dir) + self.assertFalse(os.path.isfile(os.path.join(tmp_dir, getattr(test_func, 'FOLDER'), dataexample.TestData.SHAPES)), + msg = "Download dataset test failed with dataset " + data) + self.assertEqual(capturedOutput.getvalue(),'Download cancelled\n', + msg = "Download dataset test failed with dataset " + data) # return to standard print output sys.stdout = sys.__stdout__ - if os.path.exists(os.path.join(tmp_path,tmp_dir)): - shutil.rmtree(os.path.join(tmp_path,tmp_dir)) - - @unittest.skipIf(platform.system() == 'Windows', "Skip on Windows") - @patch('cil.utilities.dataexample.input', return_value='y') - @patch('cil.utilities.dataexample.urlopen') - def test_download_data_input_y(self, mock_urlopen, input): - ''' - Test the download_data function, when the user input is 'y' to 'are you sure you want to download data' - The zipped byte string is mocked using a temporary local zip file - ''' - - with NamedTemporaryFile(suffix = '.zip') as tf: - tmp_path = os.path.dirname(tf.name) - tmp_dir = os.path.splitext(os.path.basename(tf.name))[0] - with ZipFile(tf.name, mode='w') as zip_file: - zip_file.write(self.shapes_path, arcname=dataexample.TestData.SHAPES) - - with open(tf.name, 'rb') as zip_file: - zipped_bytes = zip_file.read() - - self.mock_urlopen(mock_urlopen, zipped_bytes) - - # redirect print output - capturedOutput = StringIO() - sys.stdout = capturedOutput + # Test the zip file IS created with prompt=False i.e. prompt not used + dataexample.WALNUT.download_data(tmp_dir, prompt=False) + # Test the data file exists + self.assertTrue(os.path.isfile(os.path.join(tmp_dir, dataexample.WALNUT.FOLDER, dataexample.TestData.SHAPES)), + msg = "Download data test failed with dataset " + data) + # Test the zip file is removed + self.assertFalse(os.path.isfile(os.path.join(tmp_dir, dataexample.WALNUT.ZIP_FILE))) - for data in self.data_list: - test_func = getattr(dataexample, data) - test_func.download_data(os.path.join(tmp_path, tmp_dir)) - self.assertTrue(os.path.isfile(os.path.join(tmp_path, tmp_dir, test_func.FOLDER, dataexample.TestData.SHAPES)), msg = "Failed with dataset " + data) - - # return to standard print output - sys.stdout = sys.__stdout__ - - if os.path.exists(os.path.join(tmp_path,tmp_dir)): - shutil.rmtree(os.path.join(tmp_path,tmp_dir)) + shutil.rmtree(tmp_dir) - def test_download_data_bad_URL(self): + @patch('cil.utilities.dataexample.input', return_value='y') + def test_download_data_empty(self, input): ''' - Test an error is raised when _download_and_extract_from_url has an empty URL + Test an error is raised when download_data is used on an empty Zenodo record ''' + remote_data = dataexample.REMOTEDATA + remote_data.ZENODO_RECORD = 'empty' + remote_data.FOLDER = 'empty' + with self.assertRaises(ValueError): - dataexample.REMOTEDATA._download_and_extract_from_url('.') + remote_data.download_data('.') + + def test_a(self): + from cil.utilities.dataexample import WALNUT + \ No newline at end of file diff --git a/Wrappers/Python/test/test_stepsizes.py b/Wrappers/Python/test/test_stepsizes.py index 860da40896..1ef0fc145c 100644 --- a/Wrappers/Python/test/test_stepsizes.py +++ b/Wrappers/Python/test/test_stepsizes.py @@ -21,7 +21,7 @@ def test_step_sizes_called(self): step_size_test.get_step_size = MagicMock(return_value=.1) f = LeastSquares(A=A, b=data, c=0.5) alg = GD(initial=ig.allocate('random', seed=10), objective_function=f, step_size=step_size_test, - max_iteration=100, update_objective_interval=1) + update_objective_interval=1) alg.run(5) @@ -30,54 +30,115 @@ def test_step_sizes_called(self): step_size_test = ConstantStepSize(3) step_size_test.get_step_size = MagicMock(return_value=.1) alg = ISTA(initial=ig.allocate('random', seed=10), f=f, g=IndicatorBox(lower=0), step_size=step_size_test, - max_iteration=100, update_objective_interval=1) + update_objective_interval=1) alg.run(5) self.assertEqual(len(step_size_test.get_step_size.mock_calls), 5) step_size_test = ConstantStepSize(3) step_size_test.get_step_size = MagicMock(return_value=.1) alg = FISTA(initial=ig.allocate('random', seed=10), f=f, g=IndicatorBox(lower=0), step_size=step_size_test, - max_iteration=100, update_objective_interval=1) + update_objective_interval=1) alg.run(5) self.assertEqual(len(step_size_test.get_step_size.mock_calls), 5) +class TestStepSizeConstant(CCPiTestClass): def test_constant(self): test_stepsize = ConstantStepSize(0.3) self.assertEqual(test_stepsize.step_size, 0.3) +class TestStepSizeArmijo(CCPiTestClass): + + def setUp(self): + self.ig = VectorGeometry(2) + self.data = self.ig.allocate('random') + self.data.fill(np.array([3.5, 3.5])) + self.A = MatrixOperator(np.diag([1., 1.])) + self.f = LeastSquares(self.A, self.data) + + def test_armijo_init(self): - test_stepsize = ArmijoStepSizeRule(alpha=1e3, beta=0.4, max_iterations=40) + test_stepsize = ArmijoStepSizeRule(alpha=1e3, beta=0.4, max_iterations=40, warmstart=False) + self.assertFalse(test_stepsize.warmstart) self.assertEqual(test_stepsize.alpha_orig, 1e3) self.assertEqual(test_stepsize.beta, 0.4) self.assertEqual(test_stepsize.max_iterations, 40) test_stepsize = ArmijoStepSizeRule() + self.assertTrue(test_stepsize.warmstart) self.assertEqual(test_stepsize.alpha_orig, 1e6) self.assertEqual(test_stepsize.beta, 0.5) self.assertEqual(test_stepsize.max_iterations, np.ceil( 2 * np.log10(1e6) / np.log10(2))) def test_armijo_calculation(self): - test_stepsize = ArmijoStepSizeRule(alpha=8, beta=0.5, max_iterations=100) - ig = VectorGeometry(2) - data = ig.allocate('random') - data.fill(np.array([3.5, 3.5])) - A = MatrixOperator(np.diag([1., 1.])) - f = LeastSquares(A, data) - alg = GD(initial=ig.allocate(0), objective_function=f, - max_iteration=100, update_objective_interval=1, step_size=test_stepsize) - alg.gradient_update = ig.allocate(-1) + test_stepsize = ArmijoStepSizeRule(alpha=8, beta=0.5, max_iterations=100, warmstart=False) + + alg = GD(initial=self.ig.allocate(0), objective_function=self.f, + update_objective_interval=1, step_size=test_stepsize) + alg.gradient_update = self.ig.allocate(-1) + step_size = test_stepsize.get_step_size(alg) + self.assertAlmostEqual(step_size, 4) + + alg.gradient_update = self.ig.allocate(-.5) + step_size = test_stepsize.get_step_size(alg) + self.assertAlmostEqual(step_size, 8) + + alg.gradient_update = self.ig.allocate(-2) + step_size = test_stepsize.get_step_size(alg) + self.assertAlmostEqual(step_size, 2) + + def test_armijo_ISTA_and_FISTA(self): + test_stepsize = ArmijoStepSizeRule(alpha=8, beta=0.5, max_iterations=100, warmstart=False) + + alg = ISTA(initial=self.ig.allocate(0), f=self.f, g=IndicatorBox(lower=0), + update_objective_interval=1, step_size=test_stepsize) + alg.gradient_update = self.ig.allocate(-1) step_size = test_stepsize.get_step_size(alg) self.assertAlmostEqual(step_size, 4) - alg.gradient_update = ig.allocate(-.5) + alg.gradient_update = self.ig.allocate(-.5) step_size = test_stepsize.get_step_size(alg) self.assertAlmostEqual(step_size, 8) - alg.gradient_update = ig.allocate(-2) + alg.gradient_update = self.ig.allocate(-2) step_size = test_stepsize.get_step_size(alg) self.assertAlmostEqual(step_size, 2) + alg = FISTA(initial=self.ig.allocate(0), f=self.f, g=IndicatorBox(lower=0), + update_objective_interval=1, step_size=test_stepsize) + alg.gradient_update = self.ig.allocate(-1) + step_size = test_stepsize.get_step_size(alg) + self.assertAlmostEqual(step_size, 4) + + alg.gradient_update = self.ig.allocate(-.5) + step_size = test_stepsize.get_step_size(alg) + self.assertAlmostEqual(step_size, 8) + + alg.gradient_update = self.ig.allocate(-2) + step_size = test_stepsize.get_step_size(alg) + self.assertAlmostEqual(step_size, 2) + + def test_warmstart_true(self): + + rule = ArmijoStepSizeRule(warmstart=True, alpha=5000) + self.assertTrue(rule.warmstart) + self.assertTrue(rule.alpha_orig == 5000) + alg = GD(initial=self.ig.allocate(0), objective_function=self.f, + update_objective_interval=1, step_size=rule) + alg.update() + self.assertFalse(rule.alpha == 5000) + + def test_warmstart_false(self): + rule = ArmijoStepSizeRule(warmstart=False, alpha=5000) + self.assertFalse(rule.warmstart) + self.assertTrue(rule.alpha_orig == 5000) + alg = GD(initial=self.ig.allocate(0), objective_function=self.f, + update_objective_interval=1, step_size=rule) + alg.update() + self.assertTrue(rule.alpha_orig == 5000) + self.assertFalse(rule.alpha_orig == rule.alpha) + +class TestStepSizeBB(CCPiTestClass): def test_bb(self): n = 10 m = 5 @@ -218,7 +279,7 @@ def test_bb_converge(self): initial = ig.allocate() f = LeastSquares(Aop, b=bop, c=2) - ss_rule=ArmijoStepSizeRule(max_iterations=40) + ss_rule=ArmijoStepSizeRule(max_iterations=40, warmstart=False) alg_true = GD(initial=initial, objective_function=f, step_size=ss_rule) alg_true .run(300, verbose=0) @@ -243,3 +304,4 @@ def test_bb_converge(self): self.assertNumpyArrayAlmostEqual(alg.x.as_array(), alg_true.x.as_array(), decimal=3) + diff --git a/docs/source/utilities.rst b/docs/source/utilities.rst index dad9d4966e..816d491b4b 100644 --- a/docs/source/utilities.rst +++ b/docs/source/utilities.rst @@ -53,6 +53,40 @@ Simulated image data :members: :inherited-members: +Remote data +----------- +Remote data classes can be used to access specific datasets from zenodo. These +datasets are not packaged as part of CIL, instead the `download_data(data_dir)` +method can be used to download the dataset to a chosen data directory then loaded +from that data directory using `get(data_dir)`. + +Walnut +------ + +.. autoclass:: cil.utilities.dataexample.WALNUT + :members: + :inherited-members: + +USB +------ + +.. autoclass:: cil.utilities.dataexample.USB + :members: + :inherited-members: + +KORN +------ + +.. autoclass:: cil.utilities.dataexample.KORN + :members: + :inherited-members: + +SANDSTONE +------ +.. autoclass:: cil.utilities.dataexample.SANDSTONE + :members: + :inherited-members: + Image Quality metrics diff --git a/recipe/meta.yaml b/recipe/meta.yaml index 3355fee710..26b339bfad 100644 --- a/recipe/meta.yaml +++ b/recipe/meta.yaml @@ -70,6 +70,7 @@ requirements: - ipp >=2021.10 - tqdm - numba + - zenodo_get >=1.6 #optional packages with version dependancies run_constrained: diff --git a/scripts/create_local_env_for_cil_development.sh b/scripts/create_local_env_for_cil_development.sh index 8e65ca665f..d7564dc7dc 100755 --- a/scripts/create_local_env_for_cil_development.sh +++ b/scripts/create_local_env_for_cil_development.sh @@ -64,6 +64,7 @@ conda_args=(create --name="$name" scikit-image scipy tqdm + zenodo_get'>=1.6' ) if test -n "$cil_ver"; then echo "CIL version $cil_ver" diff --git a/scripts/requirements-test.yml b/scripts/requirements-test.yml index 73e121d81d..17668b73d8 100644 --- a/scripts/requirements-test.yml +++ b/scripts/requirements-test.yml @@ -48,3 +48,4 @@ dependencies: - pywavelets - numba - tqdm + - zenodo_get >=1.6