diff --git a/pyndl/__init__.py b/pyndl/__init__.py index 026b430..64a9b96 100644 --- a/pyndl/__init__.py +++ b/pyndl/__init__.py @@ -17,7 +17,7 @@ __author__ = ('Konstantin Sering, Marc Weitz, ' 'David-Elias Künstle, Lennard Schneider') __author_email__ = 'konstantin.sering@uni-tuebingen.de' -__version__ = '0.4.1' +__version__ = '0.4.2' __license__ = 'MIT' __description__ = ('Naive discriminative learning implements learning and ' 'classification models based on the Rescorla-Wagner ' diff --git a/pyndl/ndl.py b/pyndl/ndl.py index cd02ad3..903d5db 100644 --- a/pyndl/ndl.py +++ b/pyndl/ndl.py @@ -56,7 +56,7 @@ def events_from_file(event_path): def ndl(events, alpha, betas, lambda_=1.0, *, method='openmp', weights=None, number_of_threads=8, len_sublists=10, remove_duplicates=None, - verbose=False): + verbose=False, temporary_directory=None): """ Calculate the weights for all_outcomes over all events in event_file given by the files path. @@ -89,6 +89,10 @@ def ndl(events, alpha, betas, lambda_=1.0, *, preferred!) verbose : bool print some output if True. + temporary_directory : str + path to directory to use for storing temporary files created; + if none is provided, the operating system's default will + be used (/tmp on unix) Returns ------- @@ -155,7 +159,7 @@ def ndl(events, alpha, betas, lambda_=1.0, *, beta1, beta2 = betas - with tempfile.TemporaryDirectory(prefix="pyndl") as binary_path: + with tempfile.TemporaryDirectory(prefix="pyndl", dir=temporary_directory) as binary_path: number_events = preprocess.create_binary_event_files(events, binary_path, cue_map, outcome_map, overwrite=True, number_of_processes=number_of_threads, diff --git a/tests/test_ndl.py b/tests/test_ndl.py index 57170f8..95394b3 100644 --- a/tests/test_ndl.py +++ b/tests/test_ndl.py @@ -124,6 +124,9 @@ def test_exceptions(): ndl.ndl(FILE_PATH_SIMPLE, ALPHA, BETAS, remove_duplicates="magic") assert e_info == "remove_duplicates must be None, True or False" + with pytest.raises(FileNotFoundError, match="No such file or directory") as e_info: + ndl.ndl(FILE_PATH_SIMPLE, ALPHA, BETAS, temporary_directory="./magic") + def test_continue_learning_dict(): events_simple = pd.read_csv(FILE_PATH_SIMPLE, sep="\t") @@ -217,6 +220,11 @@ def test_return_values(result_dict_ndl, result_dict_ndl_data_array, result_ndl_t assert isinstance(result_ndl_threading, xr.DataArray) +def test_provide_temporary_directory(): + with tempfile.TemporaryDirectory(dir=TMP_PATH) as temporary_directory: + ndl.ndl(FILE_PATH_SIMPLE, ALPHA, BETAS, temporary_directory=temporary_directory) + + # Test internal consistency def test_dict_ndl_vs_ndl_threading(result_dict_ndl, result_ndl_threading):