diff --git a/pyndl/__init__.py b/pyndl/__init__.py
index 3e3899d..9c22036 100644
--- a/pyndl/__init__.py
+++ b/pyndl/__init__.py
@@ -1,7 +1,7 @@
 __author__ = ('David-Elias Künstle, Lennard Schneider, '
               'Konstantin Sering, Marc Weitz')
 __author_email__ = 'konstantin.sering@uni-tuebingen.de'
-__version__ = '0.2.4'
+__version__ = '0.2.5'
 __license__ = 'MIT'
 __description__ = ('Naive discriminative learning implements learning and '
                    'classification models based on the Rescorla-Wagner '
@@ -28,6 +28,6 @@
 :version: %s
 :author: %s
 :contact: %s
-:date: 2017-03-25
+:date: 2017-04-02
 :copyright: %s
 """ % (__description__, __version__, __author__, __author_email__, __license__)
diff --git a/pyndl/ndl.py b/pyndl/ndl.py
index 444d971..7a133df 100644
--- a/pyndl/ndl.py
+++ b/pyndl/ndl.py
@@ -142,10 +142,10 @@ def ndl(event_path, alpha, betas, lambda_=1.0, *,
 
     beta1, beta2 = betas
 
-    preprocess.create_binary_event_files(event_path, BINARY_PATH, cue_map,
-                                         outcome_map, overwrite=True,
-                                         number_of_processes=number_of_threads,
-                                         remove_duplicates=remove_duplicates)
+    number_events = preprocess.create_binary_event_files(event_path, BINARY_PATH, cue_map,
+                                                         outcome_map, overwrite=True,
+                                                         number_of_processes=number_of_threads,
+                                                         remove_duplicates=remove_duplicates)
     binary_files = [os.path.join(BINARY_PATH, binary_file)
                     for binary_file in os.listdir(BINARY_PATH)
                     if os.path.isfile(os.path.join(BINARY_PATH, binary_file))]
@@ -190,14 +190,13 @@ def worker():
     cpu_time = cpu_time_stop - cpu_time_start
     wall_time = wall_time_stop - wall_time_start
-    attrs = _attributes(event_path, alpha, betas, lambda_, cpu_time, wall_time,
-                        __name__ + "." + ndl.__name__, method=method)
-
     if weights_ini is not None:
         attrs_to_be_updated = weights_ini.attrs
-        for key in attrs_to_be_updated.keys():
-            attrs_to_be_updated[key] += ' | ' + attrs[key]
-        attrs = attrs_to_be_updated
+    else:
+        attrs_to_be_updated = None
+
+    attrs = _attributes(event_path, number_events, alpha, betas, lambda_, cpu_time, wall_time,
+                        __name__ + "." + ndl.__name__, method=method, attrs=attrs_to_be_updated)
 
     # post-processing
     weights = xr.DataArray(weights, [('outcomes', outcomes), ('cues', cues)],
@@ -205,8 +204,10 @@ def worker():
     return weights
 
 
-def _attributes(event_path, alpha, betas, lambda_, cpu_time, wall_time, function, method=None):
+def _attributes(event_path, number_events, alpha, betas, lambda_, cpu_time,
+                wall_time, function, method=None, attrs=None):
     width = max([len(ss) for ss in (event_path,
+                                    str(number_events),
                                     str(alpha),
                                     str(betas),
                                     str(lambda_),
@@ -219,27 +220,70 @@ def _attributes(event_path, alpha, betas, lambda_, cpu_time, wall_time, function
 
     def format_(ss):
         return '{0: <{width}}'.format(ss, width=width)
 
-    attrs = {'date': format_(time.strftime("%Y-%m-%d %H:%M:%S")),
-             'event_path': format_(event_path),
-             'alpha': format_(str(alpha)),
-             'betas': format_(str(betas)),
-             'lambda': format_(str(lambda_)),
-             'function': format_(function),
-             'method': format_(str(method)),
-             'cpu_time': format_(str(cpu_time)),
-             'wall_time': format_(str(wall_time)),
-             'hostname': format_(socket.gethostname()),
-             'username': format_(getpass.getuser()),
-             'pyndl': format_(__version__),
-             'numpy': format_(np.__version__),
-             'pandas': format_(pd.__version__),
-             'xarray': format_(xr.__version__),
-             'cython': format_(cython.__version__)}
-    return attrs
+    new_attrs = {'date': format_(time.strftime("%Y-%m-%d %H:%M:%S")),
+                 'event_path': format_(event_path),
+                 'number_events': format_(number_events),
+                 'alpha': format_(str(alpha)),
+                 'betas': format_(str(betas)),
+                 'lambda': format_(str(lambda_)),
+                 'function': format_(function),
+                 'method': format_(str(method)),
+                 'cpu_time': format_(str(cpu_time)),
+                 'wall_time': format_(str(wall_time)),
+                 'hostname': format_(socket.gethostname()),
+                 'username': format_(getpass.getuser()),
+                 'pyndl': format_(__version__),
+                 'numpy': format_(np.__version__),
+                 'pandas': format_(pd.__version__),
+                 'xarray': format_(xr.__version__),
+                 'cython': format_(cython.__version__)}
+
+    if attrs is not None:
+        for key in set(attrs.keys()) | set(new_attrs.keys()):
+            if key in attrs:
+                old_val = attrs[key]
+            else:
+                old_val = ''
+            if key in new_attrs:
+                new_val = new_attrs[key]
+            else:
+                new_val = format_('')
+            new_attrs[key] = old_val + ' | ' + new_val
+    return new_attrs
+
+
+class WeightDict(defaultdict):
+    # pylint: disable=missing-docstring
+
+    """
+    Subclass of defaultdict to represent outcome-cue weights.
+
+    Notes
+    -----
+    Weight for each outcome-cue combination is 0 per default.
+
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(lambda: defaultdict(float))
+
+        if 'attrs' in kwargs:
+            self.attrs = kwargs['attrs']
+        else:
+            self.attrs = {}
+
+    @property
+    def attrs(self):
+        return self._attrs
+
+    @attrs.setter
+    def attrs(self, attrs):
+        self._attrs = OrderedDict(attrs)
 
 
 def dict_ndl(event_list, alphas, betas, lambda_=1.0, *,
-             weights=None, inplace=False, remove_duplicates=None, make_data_array=False):
+             weights=None, inplace=False, remove_duplicates=None,
+             make_data_array=False):
     """
     Calculate the weights for all_outcomes over all events in event_file.
@@ -260,7 +304,7 @@ def dict_ndl(event_list, alphas, betas, lambda_=1.0, *,
     betas : (float, float)
         one value for successful prediction (reward) one for punishment
     lambda\\_ : float
-    weights : dict of dicts or None
+    weights : dict of dicts or xarray.DataArray or None
         initial weights
     inplace: {True, False}
         if True calculates the weightmatrix inplace
@@ -295,18 +339,28 @@ def dict_ndl(event_list, alphas, betas, lambda_=1.0, *,
     if not (remove_duplicates is None or isinstance(remove_duplicates, bool)):
         raise ValueError("remove_duplicates must be None, True or False")
 
-    if make_data_array:
-        wall_time_start = time.perf_counter()
-        cpu_time_start = time.process_time()
-        if isinstance(event_list, str):
-            event_path = event_list
-        else:
-            event_path = None
+    wall_time_start = time.perf_counter()
+    cpu_time_start = time.process_time()
+    if isinstance(event_list, str):
+        event_path = event_list
+    else:
+        event_path = ""
+    attrs_to_update = None
 
     # weights can be seen as an infinite outcome by cue matrix
     # weights[outcome][cue]
     if weights is None:
-        weights = defaultdict(lambda: defaultdict(float))
+        weights = WeightDict()
+    elif isinstance(weights, WeightDict):
+        attrs_to_update = weights.attrs
+    elif isinstance(weights, xr.DataArray):
+        weights_ini = weights
+        attrs_to_update = weights_ini.attrs
+        coords = weights_ini.coords
+        weights = WeightDict()
+        for oi, outcome in enumerate(coords['outcomes'].values):
+            for ci, cue in enumerate(coords['cues'].values):
+                weights[outcome][cue] = weights_ini.item((oi, ci))
     elif not isinstance(weights, defaultdict):
         raise ValueError('weights needs to be either defaultdict or None')
@@ -321,8 +375,10 @@ def dict_ndl(event_list, alphas, betas, lambda_=1.0, *,
     if isinstance(alphas, float):
         alpha = alphas
         alphas = defaultdict(lambda: alpha)
+    number_events = 0
 
     for cues, outcomes in event_list:
+        number_events += 1
         if remove_duplicates is None:
             if (len(cues) != len(set(cues)) or
                     len(outcomes) != len(set(outcomes))):
@@ -346,19 +402,20 @@ def dict_ndl(event_list, alphas, betas, lambda_=1.0, *,
             for cue in cues:
                 weights[outcome][cue] += alphas[cue] * update
 
-    if make_data_array:
-        cpu_time_stop = time.process_time()
-        wall_time_stop = time.perf_counter()
-        cpu_time = cpu_time_stop - cpu_time_start
-        wall_time = wall_time_stop - wall_time_start
-
-        attrs = _attributes(event_path, alphas, betas, lambda_, cpu_time, wall_time,
-                            __name__ + "." + dict_ndl.__name__)
+    cpu_time_stop = time.process_time()
+    wall_time_stop = time.perf_counter()
+    cpu_time = cpu_time_stop - cpu_time_start
+    wall_time = wall_time_stop - wall_time_start
+    attrs = _attributes(event_path, number_events, alphas, betas, lambda_, cpu_time, wall_time,
+                        __name__ + "." + dict_ndl.__name__, attrs=attrs_to_update)
 
+    if make_data_array:
         # post-processing
         weights = pd.DataFrame(weights)
         # weights.fillna(0.0, inplace=True)  # TODO make sure to not remove real NaNs
         weights = xr.DataArray(weights.T, dims=('outcomes', 'cues'), attrs=attrs)
+    else:
+        weights.attrs = attrs
 
     return weights
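Note on the ndl.py changes above: dict_ndl now accepts a previously returned xarray.DataArray (or a WeightDict) as initial weights, and _attributes joins the metadata of consecutive runs with ' | '. A minimal sketch of the intended behaviour, using a made-up in-memory event list (the cue and outcome names are illustrative, not from the patch):

    from pyndl import ndl

    # two toy events, each a (cues, outcomes) pair
    events = [(['a', 'b'], ['dog']),
              (['b', 'c'], ['cat'])]

    first = ndl.dict_ndl(events, 0.1, (0.1, 0.1), make_data_array=True)
    # continue learning from the returned DataArray
    second = ndl.dict_ndl(events, 0.1, (0.1, 0.1),
                          weights=first, make_data_array=True)

    # every attribute now carries one padded entry per run, pipe-joined,
    # so number_events should read roughly '2 ... | 2 ...'
    print(second.attrs['number_events'])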
diff --git a/pyndl/preprocess.py b/pyndl/preprocess.py
index 6b68d84..0b083ce 100644
--- a/pyndl/preprocess.py
+++ b/pyndl/preprocess.py
@@ -556,6 +556,10 @@ def write_events(events, filename, *, start=0, stop=4294967295, remove_duplicate
         keep multiple instances of the same cue or outcome (this is usually
         not preferred!)
 
+    Returns
+    -------
+    number_events : int
+        actual number of events written to file
 
     Binary Format
     -------------
@@ -576,7 +580,6 @@ def write_events(events, filename, *, start=0, stop=4294967295, remove_duplicate
     StopIteration : events generator is exhausted before stop is reached
 
     """
-
     with open(filename, "wb") as out_file:
         # 8 bytes header
         out_file.write(to_bytes(MAGIC_NUMBER))
@@ -589,6 +592,7 @@ def write_events(events, filename, *, start=0, stop=4294967295, remove_duplicate
         out_file.write(to_bytes(n_events_estimate))
 
         n_events = 0
+
         for ii, event in enumerate(events):
             if ii < start:
                 continue
@@ -625,10 +629,11 @@ def write_events(events, filename, *, start=0, stop=4294967295, remove_duplicate
             # the generator was exhausted earlier
             out_file.seek(8)
             out_file.write(to_bytes(n_events))
-            raise StopIteration("event generator was exhausted before stop")
+            raise StopIteration(("event generator was exhausted before stop", n_events))
 
     if n_events == 0:
         os.remove(filename)
+    return n_events
 
 
 def event_generator(event_file, cue_id_map, outcome_id_map, *, sort_within_event=False):
@@ -666,7 +671,8 @@ def _job_binary_event_file(*,
                            remove_duplicates):
     # create generator which is not pickable
     events = event_generator(event_file, cue_id_map, outcome_id_map, sort_within_event=sort_within_event)
-    write_events(events, file_name, start=start, stop=stop, remove_duplicates=remove_duplicates)
+    n_events = write_events(events, file_name, start=start, stop=stop, remove_duplicates=remove_duplicates)
+    return n_events
 
 
 def create_binary_event_files(event_file,
@@ -709,6 +715,10 @@ def create_binary_event_files(event_file,
         preferred!)
     verbose : bool
 
+    Returns
+    -------
+    number_events : int
+        sum of number of events written to binary files
 
     """
     if not os.path.isdir(path_name):
@@ -724,15 +734,24 @@ def create_binary_event_files(event_file,
             if "events_0_" in file_name:
                 os.remove(os.path.join(path_name, file_name))
 
+    number_events = 0
+
     with multiprocessing.Pool(number_of_processes) as pool:
 
         def error_callback(error):
             if isinstance(error, StopIteration):
+
+                print(error.value)
+                msg, result = error.value
+                nonlocal number_events
+                number_events += result
                 pool.close()
             else:
                 raise error
 
         def callback(result):
+            nonlocal number_events
+            number_events += result
             if verbose:
                 print("finished job")
                 sys.stdout.flush()
@@ -778,7 +797,7 @@ def callback(result):
         pool.close()
         pool.join()
         print("finished all jobs.\n")
-
+    return number_events
 
 # for example code see function test_preprocess in file
 # ./tests/test_preprocess.py.
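The counting in create_binary_event_files works because callback and error_callback run in the parent process, so a nonlocal counter can safely sum the per-job return values; the StopIteration raised by write_events additionally carries the partial count out in its .value attribute. A stripped-down, self-contained sketch of this callback-counting pattern (the job function is a toy stand-in, not pyndl code):

    import multiprocessing


    def job(chunk):
        # stand-in for _job_binary_event_file returning n_events
        return len(chunk)


    def count_all(chunks, processes=2):
        total = 0

        def callback(result):
            nonlocal total  # callbacks run in the parent process
            total += result

        with multiprocessing.Pool(processes) as pool:
            for chunk in chunks:
                pool.apply_async(job, (chunk,), callback=callback)
            pool.close()
            pool.join()
        return total


    if __name__ == '__main__':
        print(count_all([[1, 2, 3], [4, 5]]))  # prints 5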
diff --git a/tests/test_ndl.py b/tests/test_ndl.py
index a4b6bc5..2e754bd 100644
--- a/tests/test_ndl.py
+++ b/tests/test_ndl.py
@@ -49,6 +49,11 @@ def result_dict_ndl():
     return ndl.dict_ndl(FILE_PATH_SIMPLE, ALPHA, BETAS)
 
 
+@pytest.fixture(scope='module')
+def result_dict_ndl_generator():
+    return ndl.dict_ndl(ndl.events(FILE_PATH_SIMPLE), ALPHA, BETAS)
+
+
 @pytest.fixture(scope='module')
 def result_dict_ndl_data_array():
     return ndl.dict_ndl(FILE_PATH_SIMPLE, ALPHA, BETAS, make_data_array=True)
@@ -153,6 +158,18 @@ def test_continue_learning_dict():
     assert result_part != result
 
 
+def test_continue_learning_dict_ndl_data_array(result_dict_ndl, result_dict_ndl_data_array):
+    continue_from_dict = ndl.dict_ndl(FILE_PATH_SIMPLE, ALPHA, BETAS,
+                                      weights=result_dict_ndl)
+    continue_from_data_array = ndl.dict_ndl(FILE_PATH_SIMPLE, ALPHA, BETAS,
+                                            weights=result_dict_ndl_data_array)
+    unequal, unequal_ratio = compare_arrays(FILE_PATH_SIMPLE,
+                                            continue_from_dict,
+                                            continue_from_data_array)
+    print('%.2f ratio unequal' % unequal_ratio)
+    assert len(unequal) == 0
+
+
 def test_continue_learning(result_continue_learning, result_ndl_openmp):
     assert result_continue_learning.shape == result_ndl_openmp.shape
 
@@ -168,7 +185,7 @@ def test_continue_learning(result_continue_learning, result_ndl_openmp):
 
 def test_save_to_netcdf4(result_ndl_openmp):
-    weights = result_ndl_openmp
+    weights = result_ndl_openmp.copy()  # avoids changing shared test data
     path = os.path.join(TMP_PATH, "weights.nc")
     weights.to_netcdf(path)
     weights_read = xr.open_dataarray(path)
@@ -202,6 +219,13 @@ def test_dict_ndl_vs_ndl_threading(result_dict_ndl, result_ndl_threading):
     assert len(unequal) == 0
 
 
+def test_dict_ndl_vs_dict_ndl_generator(result_dict_ndl, result_dict_ndl_generator):
+    unequal, unequal_ratio = compare_arrays(FILE_PATH_SIMPLE, result_dict_ndl,
+                                            result_dict_ndl_generator)
+    print('%.2f ratio unequal' % unequal_ratio)
+    assert len(unequal) == 0
+
+
 def test_dict_ndl_data_array_vs_ndl_threading(result_ndl_threading):
     result_dict_ndl = ndl.dict_ndl(FILE_PATH_SIMPLE, ALPHA, BETAS,
                                    make_data_array=True)
@@ -229,15 +253,20 @@ def test_dict_ndl_vs_ndl_openmp(result_dict_ndl, result_ndl_openmp):
     assert len(unequal) == 0
 
 
-def test_meta_data(result_dict_ndl_data_array, result_ndl_openmp, result_ndl_threading):
+def test_meta_data(result_dict_ndl, result_dict_ndl_data_array, result_ndl_openmp, result_ndl_threading):
     attributes = {'cython', 'cpu_time', 'hostname', 'xarray', 'wall_time',
-                  'event_path', 'username', 'method', 'date', 'numpy',
+                  'event_path', 'number_events', 'username', 'method', 'date', 'numpy',
                   'betas', 'lambda', 'pyndl', 'alpha', 'pandas', 'method',
                   'function'}
-
-    assert set(result_ndl_openmp.attrs.keys()) == attributes
-    assert set(result_ndl_threading.attrs.keys()) == attributes
-    assert set(result_dict_ndl_data_array.attrs.keys()) == attributes
+    results = [result_dict_ndl, result_dict_ndl_data_array, result_ndl_threading, result_ndl_openmp]
+    for result in results:
+        assert set(result.attrs.keys()) == attributes
+
+    assert int(result_dict_ndl_data_array.attrs['number_events']) > 0
+    assert len(set(
+        [result.attrs['number_events'].strip()
+         for result in results]
+        )) == 1
 
 
 # Test against external ndl2 results
diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py
index 4e79e69..c92ad50 100644
--- a/tests/test_preprocess.py
+++ b/tests/test_preprocess.py
@@ -189,12 +189,14 @@ def test_write_events():
     # start stop
     events = event_generator(event_file, cue_id_map, outcome_id_map,
                              sort_within_event=True)
-    write_events(events, file_name, start=10, stop=20, remove_duplicates=True)
+    n_events = write_events(events, file_name, start=10, stop=20, remove_duplicates=True)
+    assert n_events == 10
     os.remove(file_name)
 
     # no events
     events = event_generator(event_file, cue_id_map, outcome_id_map,
                              sort_within_event=True)
-    write_events(events, file_name, start=100000, stop=100010, remove_duplicates=True)
+    n_events = write_events(events, file_name, start=100000, stop=100010, remove_duplicates=True)
+    assert n_events == 0
 
     _job_binary_event_file(file_name=file_name,
                            event_file=event_file,
                            cue_id_map=cue_id_map,
@@ -235,11 +237,14 @@ def test_read_binary_file():
     cue_id_map = OrderedDict(((cue, ii) for ii, cue in enumerate(cues.keys())))
     outcome_id_map = OrderedDict(((outcome, ii) for ii, outcome in enumerate(outcomes.keys())))
 
-    create_binary_event_files(abs_file_path, abs_binary_path, cue_id_map,
-                              outcome_id_map, overwrite=True, remove_duplicates=False)
+    number_events = create_binary_event_files(abs_file_path, abs_binary_path, cue_id_map,
+                                              outcome_id_map, overwrite=True, remove_duplicates=False)
     bin_events = read_binary_file(abs_binary_file_path)
     events = ndl.events(abs_file_path)
+    events_dup = ndl.events(abs_file_path)
+
+    assert number_events == len(list(events_dup))
 
     for event, bin_event in zip(events, bin_events):
         cues, outcomes = event
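A side note on the write_events change exercised above: StopIteration stores its first constructor argument in its .value attribute, which is exactly what the new error_callback unpacks. A quick illustration (the count 7 is an arbitrary example value):

    try:
        raise StopIteration(("event generator was exhausted before stop", 7))
    except StopIteration as error:
        msg, n_events = error.value  # .value holds the tuple passed above
        print(msg, n_events)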