
Commit cb3f019

Merge pull request #74 from quantling/dek/metadata

Metadata + dict_ndl improvements
derNarr authored Apr 2, 2017
2 parents 9addaca + d763fb4 commit cb3f019
Showing 5 changed files with 173 additions and 63 deletions.

pyndl/__init__.py (4 changes: 2 additions & 2 deletions)

```diff
@@ -1,7 +1,7 @@
 __author__ = ('David-Elias Künstle, Lennard Schneider, '
               'Konstantin Sering, Marc Weitz')
 __author_email__ = '[email protected]'
-__version__ = '0.2.4'
+__version__ = '0.2.5'
 __license__ = 'MIT'
 __description__ = ('Naive discriminative learning implements learning and '
                    'classification models based on the Rescorla-Wagner '
@@ -28,6 +28,6 @@
 :version: %s
 :author: %s
 :contact: %s
-:date: 2017-03-25
+:date: 2017-04-02
 :copyright: %s
 """ % (__description__, __version__, __author__, __author_email__, __license__)
```

pyndl/ndl.py (149 changes: 103 additions & 46 deletions)

```diff
@@ -142,10 +142,10 @@ def ndl(event_path, alpha, betas, lambda_=1.0, *,

     beta1, beta2 = betas

-    preprocess.create_binary_event_files(event_path, BINARY_PATH, cue_map,
-                                         outcome_map, overwrite=True,
-                                         number_of_processes=number_of_threads,
-                                         remove_duplicates=remove_duplicates)
+    number_events = preprocess.create_binary_event_files(event_path, BINARY_PATH, cue_map,
+                                                         outcome_map, overwrite=True,
+                                                         number_of_processes=number_of_threads,
+                                                         remove_duplicates=remove_duplicates)
     binary_files = [os.path.join(BINARY_PATH, binary_file)
                     for binary_file in os.listdir(BINARY_PATH)
                     if os.path.isfile(os.path.join(BINARY_PATH, binary_file))]
```
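
With the count captured here, `ndl()` can record it in the returned weights' metadata (see the `_attributes()` changes below). A minimal usage sketch, assuming pyndl 0.2.5; the file name and parameter values are placeholders:

```python
# Hedged sketch: 'events.tab.gz' stands in for a real pyndl event file.
from pyndl import ndl

weights = ndl.ndl('events.tab.gz', alpha=0.1, betas=(0.1, 0.1))

# Attribute values are width-padded strings (see format_() below), so strip:
print(weights.attrs['number_events'].strip())
```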
```diff
@@ -190,23 +190,24 @@ def worker():
     cpu_time = cpu_time_stop - cpu_time_start
     wall_time = wall_time_stop - wall_time_start

-    attrs = _attributes(event_path, alpha, betas, lambda_, cpu_time, wall_time,
-                        __name__ + "." + ndl.__name__, method=method)
-
     if weights_ini is not None:
         attrs_to_be_updated = weights_ini.attrs
-        for key in attrs_to_be_updated.keys():
-            attrs_to_be_updated[key] += ' | ' + attrs[key]
-        attrs = attrs_to_be_updated
+    else:
+        attrs_to_be_updated = None
+
+    attrs = _attributes(event_path, number_events, alpha, betas, lambda_, cpu_time, wall_time,
+                        __name__ + "." + ndl.__name__, method=method, attrs=attrs_to_be_updated)

     # post-processing
     weights = xr.DataArray(weights, [('outcomes', outcomes), ('cues', cues)],
                            attrs=attrs)
     return weights


-def _attributes(event_path, alpha, betas, lambda_, cpu_time, wall_time, function, method=None):
+def _attributes(event_path, number_events, alpha, betas, lambda_, cpu_time,
+                wall_time, function, method=None, attrs=None):
     width = max([len(ss) for ss in (event_path,
+                                    str(number_events),
                                     str(alpha),
                                     str(betas),
                                     str(lambda_),
```
```diff
@@ -219,27 +220,70 @@ def _attributes(event_path, alpha, betas, lambda_, cpu_time, wall_time, function
     def format_(ss):
         return '{0: <{width}}'.format(ss, width=width)

-    attrs = {'date': format_(time.strftime("%Y-%m-%d %H:%M:%S")),
-             'event_path': format_(event_path),
-             'alpha': format_(str(alpha)),
-             'betas': format_(str(betas)),
-             'lambda': format_(str(lambda_)),
-             'function': format_(function),
-             'method': format_(str(method)),
-             'cpu_time': format_(str(cpu_time)),
-             'wall_time': format_(str(wall_time)),
-             'hostname': format_(socket.gethostname()),
-             'username': format_(getpass.getuser()),
-             'pyndl': format_(__version__),
-             'numpy': format_(np.__version__),
-             'pandas': format_(pd.__version__),
-             'xarray': format_(xr.__version__),
-             'cython': format_(cython.__version__)}
-    return attrs
+    new_attrs = {'date': format_(time.strftime("%Y-%m-%d %H:%M:%S")),
+                 'event_path': format_(event_path),
+                 'number_events': format_(number_events),
+                 'alpha': format_(str(alpha)),
+                 'betas': format_(str(betas)),
+                 'lambda': format_(str(lambda_)),
+                 'function': format_(function),
+                 'method': format_(str(method)),
+                 'cpu_time': format_(str(cpu_time)),
+                 'wall_time': format_(str(wall_time)),
+                 'hostname': format_(socket.gethostname()),
+                 'username': format_(getpass.getuser()),
+                 'pyndl': format_(__version__),
+                 'numpy': format_(np.__version__),
+                 'pandas': format_(pd.__version__),
+                 'xarray': format_(xr.__version__),
+                 'cython': format_(cython.__version__)}
+
+    if attrs is not None:
+        for key in set(attrs.keys()) | set(new_attrs.keys()):
+            if key in attrs:
+                old_val = attrs[key]
+            else:
+                old_val = ''
+            if key in new_attrs:
+                new_val = new_attrs[key]
+            else:
+                new_val = format_('')
+            new_attrs[key] = old_val + ' | ' + new_val
+    return new_attrs
```
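
When `attrs` is handed in (metadata carried by initial weights), the merge keeps the full history: for every key in either dict, the old value, a `' | '` separator, and the new value are joined. A standalone illustration of that rule, with invented values and `format_()`'s padding left out:

```python
# Same union-and-join logic as the loop above, minus the width padding.
old = {'date': '2017-03-25', 'method': 'threading'}    # from initial weights
new = {'date': '2017-04-02', 'number_events': '2772'}  # current run

merged = {key: old.get(key, '') + ' | ' + new.get(key, '')
          for key in set(old) | set(new)}

print(merged['date'])            # 2017-03-25 | 2017-04-02
print(merged['number_events'])   # " | 2772" -- nothing recorded before
```

A weight matrix trained over several runs therefore lists one `' | '`-separated entry per run in each metadata field.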


The new `WeightDict` class, added in the same hunk:

```diff
+class WeightDict(defaultdict):
+    # pylint: disable=missing-docstring
+
+    """
+    Subclass of defaultdict to represent outcome-cue weights.
+
+    Notes
+    -----
+    Weight for each outcome-cue combination is 0 per default.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(lambda: defaultdict(float))
+
+        if 'attrs' in kwargs:
+            self.attrs = kwargs['attrs']
+        else:
+            self.attrs = {}
+
+    @property
+    def attrs(self):
+        return self._attrs
+
+    @attrs.setter
+    def attrs(self, attrs):
+        self._attrs = OrderedDict(attrs)
```
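
`WeightDict` keeps the old behaviour of the nested `defaultdict` (unseen outcome-cue pairs weigh 0.0) while adding an `.attrs` mapping for run metadata; note that `__init__` ignores positional arguments and always installs the zero-producing factory. A quick sketch, assuming pyndl 0.2.5 (the class lives in `pyndl/ndl.py` per this diff); cue and outcome strings are invented:

```python
from pyndl.ndl import WeightDict

weights = WeightDict(attrs={'pyndl': '0.2.5'})
print(weights['plural']['#ba'])   # 0.0 -- default weight for unseen pairs
weights['plural']['#ba'] += 0.1
print(weights.attrs)              # OrderedDict([('pyndl', '0.2.5')])
```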


```diff
 def dict_ndl(event_list, alphas, betas, lambda_=1.0, *,
-             weights=None, inplace=False, remove_duplicates=None, make_data_array=False):
+             weights=None, inplace=False, remove_duplicates=None,
+             make_data_array=False):
     """
     Calculate the weights for all_outcomes over all events in event_file.
@@ -260,7 +304,7 @@ def dict_ndl(event_list, alphas, betas, lambda_=1.0, *,
     betas : (float, float)
         one value for successful prediction (reward) one for punishment
     lambda\\_ : float
-    weights : dict of dicts or None
+    weights : dict of dicts or xarray.DataArray or None
         initial weights
     inplace: {True, False}
         if True calculates the weightmatrix inplace
```
```diff
@@ -295,18 +339,28 @@ def dict_ndl(event_list, alphas, betas, lambda_=1.0, *,
     if not (remove_duplicates is None or isinstance(remove_duplicates, bool)):
         raise ValueError("remove_duplicates must be None, True or False")

-    if make_data_array:
-        wall_time_start = time.perf_counter()
-        cpu_time_start = time.process_time()
-        if isinstance(event_list, str):
-            event_path = event_list
-        else:
-            event_path = None
+    wall_time_start = time.perf_counter()
+    cpu_time_start = time.process_time()
+    if isinstance(event_list, str):
+        event_path = event_list
+    else:
+        event_path = ""
+    attrs_to_update = None

     # weights can be seen as an infinite outcome by cue matrix
     # weights[outcome][cue]
     if weights is None:
-        weights = defaultdict(lambda: defaultdict(float))
+        weights = WeightDict()
+    elif isinstance(weights, WeightDict):
+        attrs_to_update = weights.attrs
+    elif isinstance(weights, xr.DataArray):
+        weights_ini = weights
+        attrs_to_update = weights_ini.attrs
+        coords = weights_ini.coords
+        weights = WeightDict()
+        for oi, outcome in enumerate(coords['outcomes'].values):
+            for ci, cue in enumerate(coords['cues'].values):
+                weights[outcome][cue] = weights_ini.item((oi, ci))
     elif not isinstance(weights, defaultdict):
         raise ValueError('weights needs to be either defaultdict or None')
```
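
The practical upshot: `dict_ndl()` can now resume from an earlier result, including an `xarray.DataArray`, and the old run's metadata is threaded through `attrs_to_update` into the new weights. A hedged sketch with a toy one-event list (real workloads stream events from pyndl's event files):

```python
from pyndl.ndl import dict_ndl

events = [(['#ba', 'ba#'], ['plural'])]   # one event: (cues, outcomes)

# First pass: request an xarray.DataArray ...
w1 = dict_ndl(events, alphas=0.1, betas=(0.1, 0.1), make_data_array=True)

# ... second pass: feed it straight back in as the initial weights.
w2 = dict_ndl(events, alphas=0.1, betas=(0.1, 0.1), weights=w1)
print(w2.attrs['number_events'])   # both runs' counts, joined with ' | '
```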

```diff
@@ -321,8 +375,10 @@ def dict_ndl(event_list, alphas, betas, lambda_=1.0, *,
     if isinstance(alphas, float):
         alpha = alphas
         alphas = defaultdict(lambda: alpha)
+    number_events = 0

     for cues, outcomes in event_list:
+        number_events += 1
         if remove_duplicates is None:
             if (len(cues) != len(set(cues)) or
                     len(outcomes) != len(set(outcomes))):
```
```diff
@@ -346,19 +402,20 @@ def dict_ndl(event_list, alphas, betas, lambda_=1.0, *,
             for cue in cues:
                 weights[outcome][cue] += alphas[cue] * update

-    if make_data_array:
-        cpu_time_stop = time.process_time()
-        wall_time_stop = time.perf_counter()
-        cpu_time = cpu_time_stop - cpu_time_start
-        wall_time = wall_time_stop - wall_time_start
-
-        attrs = _attributes(event_path, alphas, betas, lambda_, cpu_time, wall_time,
-                            __name__ + "." + dict_ndl.__name__)
+    cpu_time_stop = time.process_time()
+    wall_time_stop = time.perf_counter()
+    cpu_time = cpu_time_stop - cpu_time_start
+    wall_time = wall_time_stop - wall_time_start
+    attrs = _attributes(event_path, number_events, alphas, betas, lambda_, cpu_time, wall_time,
+                        __name__ + "." + dict_ndl.__name__, attrs=attrs_to_update)

+    if make_data_array:
         # post-processing
         weights = pd.DataFrame(weights)
         # weights.fillna(0.0, inplace=True)  # TODO make sure to not remove real NaNs
         weights = xr.DataArray(weights.T, dims=('outcomes', 'cues'), attrs=attrs)
+    else:
+        weights.attrs = attrs

     return weights
```
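
Because the timing and the `_attributes()` call now run unconditionally, metadata is attached either way: to the `DataArray` when `make_data_array=True`, otherwise to the returned `WeightDict`. Continuing the toy example above:

```python
from pyndl.ndl import dict_ndl

w = dict_ndl([(['#ba', 'ba#'], ['plural'])], alphas=0.1, betas=(0.1, 0.1))
print(sorted(w.attrs)[:4])   # e.g. ['alpha', 'betas', 'cpu_time', 'cython']
```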

pyndl/preprocess.py (27 changes: 23 additions & 4 deletions)

```diff
@@ -556,6 +556,10 @@ def write_events(events, filename, *, start=0, stop=4294967295, remove_duplicate
         keep multiple instances of the same cue or outcome (this is usually not
         preferred!)

+    Returns
+    -------
+    number_events : int
+        actual number of events written to file

     Binary Format
     -------------
@@ -576,7 +580,6 @@ def write_events(events, filename, *, start=0, stop=4294967295, remove_duplicate
     StopIteration : events generator is exhausted before stop is reached
     """

-
     with open(filename, "wb") as out_file:
         # 8 bytes header
         out_file.write(to_bytes(MAGIC_NUMBER))
@@ -589,6 +592,7 @@ def write_events(events, filename, *, start=0, stop=4294967295, remove_duplicate
         out_file.write(to_bytes(n_events_estimate))

         n_events = 0
+
         for ii, event in enumerate(events):
             if ii < start:
                 continue
@@ -625,10 +629,11 @@ def write_events(events, filename, *, start=0, stop=4294967295, remove_duplicate
             # the generator was exhausted earlier
             out_file.seek(8)
             out_file.write(to_bytes(n_events))
-            raise StopIteration("event generator was exhausted before stop")
+            raise StopIteration(("event generator was exhausted before stop", n_events))

     if n_events == 0:
         os.remove(filename)
+    return n_events
```
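
`write_events()` now reports how many events it actually wrote; when the generator runs dry before `stop`, the count travels out inside the `StopIteration` payload, which the pool's `error_callback` below unpacks. A hedged direct-call sketch (the event file and id maps are placeholders for what pyndl's preprocessing builds):

```python
from pyndl.preprocess import event_generator, write_events

cue_id_map = {'#ba': 0, 'ba#': 1}        # placeholder maps
outcome_id_map = {'plural': 0}

events = event_generator('events.tab.gz', cue_id_map, outcome_id_map)
try:
    n_events = write_events(events, 'binary/events_0_0.dat', stop=10000)
except StopIteration as err:
    msg, n_events = err.value            # generator exhausted before stop
```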


def event_generator(event_file, cue_id_map, outcome_id_map, *, sort_within_event=False):
Expand Down Expand Up @@ -666,7 +671,8 @@ def _job_binary_event_file(*,
remove_duplicates):
# create generator which is not pickable
events = event_generator(event_file, cue_id_map, outcome_id_map, sort_within_event=sort_within_event)
write_events(events, file_name, start=start, stop=stop, remove_duplicates=remove_duplicates)
n_events = write_events(events, file_name, start=start, stop=stop, remove_duplicates=remove_duplicates)
return n_events


def create_binary_event_files(event_file,
Expand Down Expand Up @@ -709,6 +715,10 @@ def create_binary_event_files(event_file,
preferred!)
verbose : bool
Returns
-------
number_events : int
sum of number of events written to binary files
"""

if not os.path.isdir(path_name):
```diff
@@ -724,15 +734,24 @@
             if "events_0_" in file_name:
                 os.remove(os.path.join(path_name, file_name))

+    number_events = 0
+
     with multiprocessing.Pool(number_of_processes) as pool:

         def error_callback(error):
             if isinstance(error, StopIteration):
-                print(error.value)
+                msg, result = error.value
+                nonlocal number_events
+                number_events += result
                 pool.close()
             else:
                 raise error

         def callback(result):
+            nonlocal number_events
+            number_events += result
             if verbose:
                 print("finished job")
                 sys.stdout.flush()
```
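
Both callbacks fold the per-chunk counts into `number_events` through `nonlocal`; this is safe because `Pool` invokes callbacks on a result-handler thread inside the parent process, not in the workers. A standalone illustration of the same accumulation pattern:

```python
import multiprocessing


def count_all(chunks):
    total = 0

    def callback(result):
        # Runs in the parent process, so it may close over `total`.
        nonlocal total
        total += result

    with multiprocessing.Pool(2) as pool:
        for chunk in chunks:
            pool.apply_async(len, (chunk,), callback=callback)
        pool.close()
        pool.join()
    return total


if __name__ == '__main__':
    print(count_all([[1, 2], [3, 4, 5]]))   # 5
```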
```diff
@@ -778,7 +797,7 @@ def callback(result):
         pool.close()
         pool.join()
         print("finished all jobs.\n")
-
+    return number_events

 # for example code see function test_preprocess in file
 # ./tests/test_preprocess.py.
```
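
Callers now get the grand total back, which is exactly what `ndl()` in `pyndl/ndl.py` consumes above. A hedged sketch mirroring that call (the event file and id maps are placeholders):

```python
from pyndl import preprocess

cue_id_map = {'#ba': 0, 'ba#': 1}        # placeholder maps
outcome_id_map = {'plural': 0}

number_events = preprocess.create_binary_event_files(
    'events.tab.gz', 'binary/', cue_id_map, outcome_id_map,
    overwrite=True, number_of_processes=2)
print(number_events)
```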