diff --git a/CHANGELOG.md b/CHANGELOG.md index 0eb818e67e..4afe8479ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,6 +54,7 @@ - BlockOperator that would return a BlockDataContainer of shape (1,1) now returns the appropriate DataContainer. BlockDataContainer direct and adjoint methods accept DataContainer as parameter (#1802). - BlurringOperator: remove check for geometry class (old SIRF integration bug) (#1807) - The `ZeroFunction` and `ConstantFunction` now have a Lipschitz constant of 1. (#1768) + - Update dataexample remote data download to work with windows and use zenodo_get for data download (#1774) - Changes that break backwards compatibility: - Merged the files `BlockGeometry.py` and `BlockDataContainer.py` in `framework` to one file `block.py`. Please use `from cil.framework import BlockGeometry, BlockDataContainer` as before (#1799) - Bug fix in `FGP_TV` function to set the default behaviour not to enforce non-negativity (#1826). diff --git a/Wrappers/Python/cil/utilities/dataexample.py b/Wrappers/Python/cil/utilities/dataexample.py index cd0fb93de3..1acdf66411 100644 --- a/Wrappers/Python/cil/utilities/dataexample.py +++ b/Wrappers/Python/cil/utilities/dataexample.py @@ -25,10 +25,9 @@ import os.path import sys from zipfile import ZipFile -from urllib.request import urlopen -from io import BytesIO from scipy.io import loadmat from cil.io import NEXUSDataReader, NikonDataReader, ZEISSDataReader +from zenodo_get import zenodo_get class DATA(object): @classmethod @@ -46,21 +45,15 @@ def get(cls, size=None, scale=(0,1), **kwargs): class REMOTEDATA(DATA): FOLDER = '' - URL = '' - FILE_SIZE = '' + ZENODO_RECORD = '' + ZIP_FILE = '' @classmethod def get(cls, data_dir): return None @classmethod - def _download_and_extract_from_url(cls, data_dir): - with urlopen(cls.URL) as response: - with BytesIO(response.read()) as bytes, ZipFile(bytes) as zipfile: - zipfile.extractall(path = data_dir) - - @classmethod - def download_data(cls, data_dir): + def 
download_data(cls, data_dir, prompt=True): ''' Download a dataset from a remote repository @@ -71,14 +64,18 @@ def download_data(cls, data_dir): ''' if os.path.isdir(os.path.join(data_dir, cls.FOLDER)): - print("Dataset already exists in " + data_dir) + print("Dataset folder already exists in " + data_dir) else: - if input("Are you sure you want to download " + cls.FILE_SIZE + " dataset from " + cls.URL + " ? (y/n)") == "y": - print('Downloading dataset from ' + cls.URL) - cls._download_and_extract_from_url(os.path.join(data_dir,cls.FOLDER)) - print('Download complete') - else: + user_input = input("Are you sure you want to download {cls.ZIP_FILE} dataset from Zenodo record {cls.ZENODO_RECORD}? [Y/n]: ") if prompt else 'y' + if user_input.lower() not in ('y', 'yes'): print('Download cancelled') + return False + + zenodo_get([cls.ZENODO_RECORD, '-g', cls.ZIP_FILE, '-o', data_dir]) + with ZipFile(os.path.join(data_dir, cls.ZIP_FILE), 'r') as zip_ref: + zip_ref.extractall(os.path.join(data_dir, cls.FOLDER)) + os.remove(os.path.join(data_dir, cls.ZIP_FILE)) + return True class BOAT(CILDATA): @classmethod @@ -195,15 +192,21 @@ def get(cls, **kwargs): class WALNUT(REMOTEDATA): ''' A microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516 + + Example + -------- + >>> data_dir = 'my_PC/data_folder' + >>> dataexample.WALNUT.download_data(data_dir) # download the data + >>> dataexample.WALNUT.get(data_dir) # load the data ''' FOLDER = 'walnut' - URL = 'https://zenodo.org/record/4822516/files/walnut.zip' - FILE_SIZE = '6.4 GB' + ZENODO_RECORD = '4822516' + ZIP_FILE = 'walnut.zip' @classmethod def get(cls, data_dir): ''' - A microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516 + Get the microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516 This function returns the raw projection data from the .txrm file Parameters @@ -227,15 +230,21 @@ def get(cls, data_dir): class USB(REMOTEDATA): ''' 
A microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516 + + Example + -------- + >>> data_dir = 'my_PC/data_folder' + >>> dataexample.USB.download_data(data_dir) # download the data + >>> dataexample.USB.get(data_dir) # load the data ''' FOLDER = 'USB' - URL = 'https://zenodo.org/record/4822516/files/usb.zip' - FILE_SIZE = '3.2 GB' + ZENODO_RECORD = '4822516' + ZIP_FILE = 'usb.zip' @classmethod def get(cls, data_dir): ''' - A microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516 + Get the microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516 This function returns the raw projection data from the .txrm file Parameters @@ -259,15 +268,21 @@ def get(cls, data_dir): class KORN(REMOTEDATA): ''' A microcomputed tomography dataset of a sunflower seeds in a box from https://zenodo.org/records/6874123 + + Example + -------- + >>> data_dir = 'my_PC/data_folder' + >>> dataexample.KORN.download_data(data_dir) # download the data + >>> dataexample.KORN.get(data_dir) # load the data ''' FOLDER = 'korn' - URL = 'https://zenodo.org/record/6874123/files/korn.zip' - FILE_SIZE = '2.9 GB' + ZENODO_RECORD = '6874123' + ZIP_FILE = 'korn.zip' @classmethod def get(cls, data_dir): ''' - A microcomputed tomography dataset of a sunflower seeds in a box from https://zenodo.org/records/6874123 + Get the microcomputed tomography dataset of a sunflower seeds in a box from https://zenodo.org/records/6874123 This function returns the raw projection data from the .xtekct file Parameters @@ -279,6 +294,7 @@ def get(cls, data_dir): ------- ImageData The korn dataset + ''' filepath = os.path.join(data_dir, cls.FOLDER, 'Korn i kasse','47209 testscan korn01_recon.xtekct') try: @@ -293,10 +309,40 @@ class SANDSTONE(REMOTEDATA): ''' A synchrotron x-ray tomography dataset of sandstone from https://zenodo.org/records/4912435 A small subset of the data containing selected projections and 
4 slices of the reconstruction + + Example + -------- + >>> data_dir = 'my_PC/data_folder' + >>> dataexample.SANDSTONE.download_data(data_dir) # download the data + >>> dataexample.SANDSTONE.get(data_dir) # load the data ''' FOLDER = 'sandstone' - URL = 'https://zenodo.org/records/4912435/files/small.zip' - FILE_SIZE = '227 MB' + ZENODO_RECORD = '4912435' + ZIP_FILE = 'small.zip' + + @classmethod + def get(cls, data_dir, filename): + ''' + Get the synchrotron x-ray tomography dataset of sandstone from https://zenodo.org/records/4912435 + A small subset of the data containing selected projections and 4 slices of the reconstruction + Parameters + ---------- + data_dir: str + The path to the directory where the dataset is stored. Data can be downloaded with dataexample.SANDSTONE.download_data(data_dir) + + file: str + The slices or projections to return, specify the path to the file within the data_dir + + Returns + ------- + ImageData + The selected sandstone dataset + ''' + extension = os.path.splitext(filename)[1] + if extension == '.mat': + return loadmat(os.path.join(data_dir,filename)) + raise KeyError(f"Unknown extension: {extension}") + class TestData(object): '''Class to return test data diff --git a/Wrappers/Python/test/test_dataexample.py b/Wrappers/Python/test/test_dataexample.py index b1b2acaa52..4ace99a58f 100644 --- a/Wrappers/Python/test/test_dataexample.py +++ b/Wrappers/Python/test/test_dataexample.py @@ -25,11 +25,11 @@ from testclass import CCPiTestClass import platform import numpy as np -from unittest.mock import patch, MagicMock -from urllib import request +from unittest.mock import patch from zipfile import ZipFile from io import StringIO -from tempfile import NamedTemporaryFile +import uuid +from zenodo_get import zenodo_get initialise_tests() @@ -157,116 +157,89 @@ def test_load_SIMULATED_CONE_BEAM_DATA(self): class TestRemoteData(unittest.TestCase): def setUp(self): - self.data_list = ['WALNUT','USB','KORN','SANDSTONE'] - self.shapes_path = 
os.path.join(dataexample.CILDATA.data_dir, dataexample.TestData.SHAPES) - def mock_urlopen(self, mock_urlopen, zipped_bytes): - mock_response = MagicMock() - mock_response.read.return_value = zipped_bytes - mock_response.__enter__.return_value = mock_response - mock_urlopen.return_value = mock_response - @unittest.skipIf(platform.system() == 'Windows', "Skip on Windows") - @patch('cil.utilities.dataexample.urlopen') - def test_unzip_remote_data(self, mock_urlopen): - ''' - Test the _download_and_extract_data_from_url function correctly extracts files from a byte string - The zipped byte string is mocked using a temporary local zip file + def mock_zenodo_get(*args): + # mock zenodo_get by making a zip file containing the shapes test data when the function is called + shapes_path = os.path.join(dataexample.CILDATA.data_dir, dataexample.TestData.SHAPES) + with ZipFile(os.path.join(args[0][4], args[0][2]), mode='w') as zip_file: + zip_file.write(shapes_path, arcname=dataexample.TestData.SHAPES) + + + @patch('cil.utilities.dataexample.input', return_value='y') + @patch('cil.utilities.dataexample.zenodo_get', side_effect=mock_zenodo_get) + def test_download_data_input_y(self, mock_zenodo_get, input): ''' - - # create a temporary zip file to test the function - with NamedTemporaryFile(suffix = '.zip') as tf: - tmp_path = os.path.dirname(tf.name) - tmp_dir = os.path.splitext(os.path.basename(tf.name))[0] - with ZipFile(tf.name, mode='w') as zip_file: - zip_file.write(self.shapes_path, arcname=dataexample.TestData.SHAPES) - - with open(tf.name, 'rb') as zip_file: - zipped_bytes = zip_file.read() - - self.mock_urlopen(mock_urlopen, zipped_bytes) - dataexample.REMOTEDATA._download_and_extract_from_url(os.path.join(tmp_path, tmp_dir)) + Test the download_data function, when the user input is 'y' to 'are you sure you want to download data' + The user input to confirm the download is mocked as 'y' + The zip file download is mocked by creating a zip file locally + Test the 
download_data function correctly extracts files from the zip file + ''' + # create a temporary folder in the CIL data directory + tmp_dir = os.path.join(dataexample.CILDATA.data_dir, str(uuid.uuid4())) + os.makedirs(tmp_dir) + # redirect print output + capturedOutput = StringIO() + sys.stdout = capturedOutput + for data in self.data_list: + test_func = getattr(dataexample, data) + test_func.download_data(tmp_dir) + # Test the data file exists + self.assertTrue(os.path.isfile(os.path.join(tmp_dir, getattr(test_func, 'FOLDER'), dataexample.TestData.SHAPES)), + msg = "Download data test failed with dataset " + data) + # Test the zip file is removed + self.assertFalse(os.path.isfile(os.path.join(tmp_dir, getattr(test_func, 'ZIP_FILE')))) + # return to standard print output + sys.stdout = sys.__stdout__ + shutil.rmtree(tmp_dir) - self.assertTrue(os.path.isfile(os.path.join(tmp_path, tmp_dir, dataexample.TestData.SHAPES))) - if os.path.exists(os.path.join(tmp_path,tmp_dir)): - shutil.rmtree(os.path.join(tmp_path,tmp_dir)) - - @unittest.skipIf(platform.system() == 'Windows', "Skip on Windows") - @patch('cil.utilities.dataexample.input', return_value='n') - @patch('cil.utilities.dataexample.urlopen') - def test_download_data_input_n(self, mock_urlopen, input): + @patch('cil.utilities.dataexample.input', return_value='n') + @patch('cil.utilities.dataexample.zenodo_get', side_effect=mock_zenodo_get) + def test_download_data_input_n(self, mock_zenodo_get, input): ''' Test the download_data function, when the user input is 'n' to 'are you sure you want to download data' - The zipped byte string is mocked using a temporary local zip file ''' - - # create a temporary zip file to test the function - with NamedTemporaryFile(suffix = '.zip') as tf: - tmp_path = os.path.dirname(tf.name) - tmp_dir = os.path.splitext(os.path.basename(tf.name))[0] - with ZipFile(tf.name, mode='w') as zip_file: - zip_file.write(self.shapes_path, arcname=dataexample.TestData.SHAPES) - - with 
open(tf.name, 'rb') as zip_file: - zipped_bytes = zip_file.read() - - self.mock_urlopen(mock_urlopen, zipped_bytes) - + # create a temporary folder in the CIL data directory + tmp_dir = os.path.join(dataexample.CILDATA.data_dir, str(uuid.uuid4())) + os.makedirs(tmp_dir) for data in self.data_list: # redirect print output - capturedOutput = StringIO() - sys.stdout = capturedOutput + capturedOutput = StringIO() + sys.stdout = capturedOutput test_func = getattr(dataexample, data) - test_func.download_data(os.path.join(tmp_path, tmp_dir)) - self.assertFalse(os.path.isfile(os.path.join(tmp_path, tmp_dir, test_func.FOLDER, dataexample.TestData.SHAPES)), msg = "Failed with dataset " + data) - self.assertEqual(capturedOutput.getvalue(),'Download cancelled\n', msg = "Failed with dataset " + data) + test_func.download_data(tmp_dir) + self.assertFalse(os.path.isfile(os.path.join(tmp_dir, getattr(test_func, 'FOLDER'), dataexample.TestData.SHAPES)), + msg = "Download dataset test failed with dataset " + data) + self.assertEqual(capturedOutput.getvalue(),'Download cancelled\n', + msg = "Download dataset test failed with dataset " + data) # return to standard print output sys.stdout = sys.__stdout__ - if os.path.exists(os.path.join(tmp_path,tmp_dir)): - shutil.rmtree(os.path.join(tmp_path,tmp_dir)) - - @unittest.skipIf(platform.system() == 'Windows', "Skip on Windows") - @patch('cil.utilities.dataexample.input', return_value='y') - @patch('cil.utilities.dataexample.urlopen') - def test_download_data_input_y(self, mock_urlopen, input): - ''' - Test the download_data function, when the user input is 'y' to 'are you sure you want to download data' - The zipped byte string is mocked using a temporary local zip file - ''' - - with NamedTemporaryFile(suffix = '.zip') as tf: - tmp_path = os.path.dirname(tf.name) - tmp_dir = os.path.splitext(os.path.basename(tf.name))[0] - with ZipFile(tf.name, mode='w') as zip_file: - zip_file.write(self.shapes_path, arcname=dataexample.TestData.SHAPES) 
- - with open(tf.name, 'rb') as zip_file: - zipped_bytes = zip_file.read() - - self.mock_urlopen(mock_urlopen, zipped_bytes) - - # redirect print output - capturedOutput = StringIO() - sys.stdout = capturedOutput + # Test the zip file IS created with prompt=False i.e. prompt not used + dataexample.WALNUT.download_data(tmp_dir, prompt=False) + # Test the data file exists + self.assertTrue(os.path.isfile(os.path.join(tmp_dir, dataexample.WALNUT.FOLDER, dataexample.TestData.SHAPES)), + msg = "Download data test failed with dataset " + data) + # Test the zip file is removed + self.assertFalse(os.path.isfile(os.path.join(tmp_dir, dataexample.WALNUT.ZIP_FILE))) - for data in self.data_list: - test_func = getattr(dataexample, data) - test_func.download_data(os.path.join(tmp_path, tmp_dir)) - self.assertTrue(os.path.isfile(os.path.join(tmp_path, tmp_dir, test_func.FOLDER, dataexample.TestData.SHAPES)), msg = "Failed with dataset " + data) - - # return to standard print output - sys.stdout = sys.__stdout__ - - if os.path.exists(os.path.join(tmp_path,tmp_dir)): - shutil.rmtree(os.path.join(tmp_path,tmp_dir)) + shutil.rmtree(tmp_dir) - def test_download_data_bad_URL(self): + @patch('cil.utilities.dataexample.input', return_value='y') + def test_download_data_empty(self, input): ''' - Test an error is raised when _download_and_extract_from_url has an empty URL + Test an error is raised when download_data is used on an empty Zenodo record ''' + remote_data = dataexample.REMOTEDATA + remote_data.ZENODO_RECORD = 'empty' + remote_data.FOLDER = 'empty' + with self.assertRaises(ValueError): - dataexample.REMOTEDATA._download_and_extract_from_url('.') + remote_data.download_data('.') + + def test_a(self): + from cil.utilities.dataexample import WALNUT + \ No newline at end of file diff --git a/docs/source/utilities.rst b/docs/source/utilities.rst index dad9d4966e..816d491b4b 100644 --- a/docs/source/utilities.rst +++ b/docs/source/utilities.rst @@ -53,6 +53,40 @@ Simulated image data 
:members: :inherited-members: +Remote data +----------- +Remote data classes can be used to access specific datasets from Zenodo. These +datasets are not packaged as part of CIL. Instead, use the `download_data(data_dir)` +method to download a dataset to a chosen data directory, then load it +from that data directory using `get(data_dir)`. + +Walnut +------ + +.. autoclass:: cil.utilities.dataexample.WALNUT + :members: + :inherited-members: + +USB +------ + +.. autoclass:: cil.utilities.dataexample.USB + :members: + :inherited-members: + +KORN +------ + +.. autoclass:: cil.utilities.dataexample.KORN + :members: + :inherited-members: + +SANDSTONE +--------- +.. autoclass:: cil.utilities.dataexample.SANDSTONE + :members: + :inherited-members: + Image Quality metrics diff --git a/recipe/meta.yaml b/recipe/meta.yaml index 3355fee710..26b339bfad 100644 --- a/recipe/meta.yaml +++ b/recipe/meta.yaml @@ -70,6 +70,7 @@ requirements: - ipp >=2021.10 - tqdm - numba + - zenodo_get >=1.6 #optional packages with version dependancies run_constrained: diff --git a/scripts/create_local_env_for_cil_development.sh b/scripts/create_local_env_for_cil_development.sh index 8e65ca665f..d7564dc7dc 100755 --- a/scripts/create_local_env_for_cil_development.sh +++ b/scripts/create_local_env_for_cil_development.sh @@ -64,6 +64,7 @@ conda_args=(create --name="$name" scikit-image scipy tqdm + zenodo_get'>=1.6' ) if test -n "$cil_ver"; then echo "CIL version $cil_ver" diff --git a/scripts/requirements-test.yml b/scripts/requirements-test.yml index 73e121d81d..17668b73d8 100644 --- a/scripts/requirements-test.yml +++ b/scripts/requirements-test.yml @@ -48,3 +48,4 @@ dependencies: - pywavelets - numba - tqdm + - zenodo_get >=1.6