Skip to content

Commit

Permalink
add option to set temporary directory (#139)
Browse files Browse the repository at this point in the history
* add option to set temporary directory in pyndl.ndl.ndl
  • Loading branch information
kuchenrolle authored and derNarr committed Feb 13, 2018
1 parent 45ddbca commit 388f343
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyndl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
__author__ = ('Konstantin Sering, Marc Weitz, '
'David-Elias Künstle, Lennard Schneider')
__author_email__ = '[email protected]'
__version__ = '0.4.1'
__version__ = '0.4.2'
__license__ = 'MIT'
__description__ = ('Naive discriminative learning implements learning and '
'classification models based on the Rescorla-Wagner '
Expand Down
8 changes: 6 additions & 2 deletions pyndl/ndl.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def events_from_file(event_path):
def ndl(events, alpha, betas, lambda_=1.0, *,
method='openmp', weights=None,
number_of_threads=8, len_sublists=10, remove_duplicates=None,
verbose=False):
verbose=False, temporary_directory=None):
"""
Calculate the weights for all_outcomes over all events in event_file
given by the files path.
Expand Down Expand Up @@ -89,6 +89,10 @@ def ndl(events, alpha, betas, lambda_=1.0, *,
preferred!)
verbose : bool
print some output if True.
temporary_directory : str
path to directory to use for storing temporary files created;
if none is provided, the operating system's default will
be used (/tmp on unix)
Returns
-------
Expand Down Expand Up @@ -155,7 +159,7 @@ def ndl(events, alpha, betas, lambda_=1.0, *,

beta1, beta2 = betas

with tempfile.TemporaryDirectory(prefix="pyndl") as binary_path:
with tempfile.TemporaryDirectory(prefix="pyndl", dir=temporary_directory) as binary_path:
number_events = preprocess.create_binary_event_files(events, binary_path, cue_map,
outcome_map, overwrite=True,
number_of_processes=number_of_threads,
Expand Down
8 changes: 8 additions & 0 deletions tests/test_ndl.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ def test_exceptions():
ndl.ndl(FILE_PATH_SIMPLE, ALPHA, BETAS, remove_duplicates="magic")
assert e_info == "remove_duplicates must be None, True or False"

with pytest.raises(FileNotFoundError, match="No such file or directory") as e_info:
ndl.ndl(FILE_PATH_SIMPLE, ALPHA, BETAS, temporary_directory="./magic")


def test_continue_learning_dict():
events_simple = pd.read_csv(FILE_PATH_SIMPLE, sep="\t")
Expand Down Expand Up @@ -217,6 +220,11 @@ def test_return_values(result_dict_ndl, result_dict_ndl_data_array, result_ndl_t
assert isinstance(result_ndl_threading, xr.DataArray)


def test_provide_temporary_directory():
with tempfile.TemporaryDirectory(dir=TMP_PATH) as temporary_directory:
ndl.ndl(FILE_PATH_SIMPLE, ALPHA, BETAS, temporary_directory=temporary_directory)


# Test internal consistency

def test_dict_ndl_vs_ndl_threading(result_dict_ndl, result_ndl_threading):
Expand Down

0 comments on commit 388f343

Please sign in to comment.