Commit: General improvements (#196)
* changes API and adds some general code improvements (#166)

* renames number_of_processes to n_jobs (fixes #103)

* exposes make_data_array as ndl.data_array(...) and adds tests (fixes #122)

* adds support for learning events from generators in ndl.ndl (fixes #165)

* changes tuples into namedtuples for return values (fixes #108)

* adds a sanity check for too many cues or outcomes (see #169 for details)

* gets rid of a DeprecationWarning

* checks that outcomes_vectors, cue_vectors and weights are C-contiguous before passing them to Cython functions (fixes #192; see the sketch below)

* adds conditional imports to support Python 3.7 for the time being

* merges develop and deletes it afterwards (fixes #193)

Co-authored-by: Marc Weitz <[email protected]>
derNarr and Trybnetic authored Jun 21, 2021
1 parent bb43350 commit 1b4e2c7
Showing 15 changed files with 275 additions and 135 deletions.
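
Before the file-by-file diff, the C-contiguity bullet as a hedged, self-contained sketch. The files it actually touches are not among those shown below, and `ensure_c_contiguous` is a hypothetical helper, not pyndl API:

import numpy as np

def ensure_c_contiguous(arr):
    # Cython kernels that take typed memoryviews require C-contiguous
    # buffers; copy non-contiguous input before handing it over.
    if not arr.flags['C_CONTIGUOUS']:
        arr = np.ascontiguousarray(arr)
    return arr

# A transposed 2-D array is Fortran-contiguous, so it gets copied here:
weights = ensure_c_contiguous(np.arange(6.0).reshape(2, 3).T)
assert weights.flags['C_CONTIGUOUS']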
doc/source/examples.rst (2 changes: 1 addition & 1 deletion)

@@ -109,7 +109,7 @@
 former one using `openMP <http://www.openmp.org/>`_ and therefore being expected
 to be much faster when analyzing larger data. Besides, you can set three
 technical arguments which we will not change here:

-1. ``number_of_threads`` (int) giving the number of threads in which the job
+1. ``n_jobs`` (int) giving the number of threads in which the job
    should be executed (default=2)
 2. ``sequence`` (int) giving the length of sublists generated from all outcomes
    (default=10)
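
Aside (not part of the diff): a minimal sketch of the renamed keyword in use, per the documentation above; the event file path and parameter values are placeholders:

from pyndl import ndl

# `n_jobs` replaces the old `number_of_threads` name from the docs above;
# 'events.tab.gz' is a placeholder pyndl event file.
weights = ndl.ndl('events.tab.gz', alpha=0.1, betas=(0.1, 0.1), n_jobs=2)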
pyndl/__init__.py (25 changes: 18 additions & 7 deletions)

@@ -11,20 +11,27 @@
 import os
 import sys
 import multiprocessing as mp
-from pip._vendor import pkg_resources
+try:
+    from importlib.metadata import requires
+except ModuleNotFoundError:  # python 3.7 and before
+    requires = None
+try:
+    from packaging.requirements import Requirement
+except ModuleNotFoundError:  # this should only happend during setup phase
+    Requirement = None


 __author__ = ('Konstantin Sering, Marc Weitz, '
               'David-Elias Künstle, Lennard Schneider, '
               'Elnaz Shafaei-Bajestan')
 __author_email__ = '[email protected]'
-__version__ = '0.8.2'
+__version__ = '0.8.1'
 __license__ = 'MIT'
 __description__ = ('Naive discriminative learning implements learning and '
                    'classification models based on the Rescorla-Wagner '
                    'equations.')
 __classifiers__ = [
-    'Development Status :: 3 - Alpha',
+    'Development Status :: 4 - Beta',
     'Environment :: Console',
     'Intended Audience :: Science/Research',
     'License :: OSI Approved :: MIT License',

@@ -45,8 +52,9 @@ def sysinfo():
     """
     Prints system the dependency information
     """
-    pyndl = pkg_resources.working_set.by_key["pyndl"]
-    dependencies = [r.project_name for r in pyndl.requires()]
+    if requires:
+        dependencies = [Requirement(req).name for req in requires('pyndl')
+                        if not Requirement(req).marker]

     header = ("Pyndl Information\n"
               "=================\n\n")

@@ -78,7 +86,10 @@ def sysinfo():
     deps = ("Dependencies\n"
             "------------\n")

-    deps += "\n".join("{pkg.__name__}: {pkg.__version__}".format(pkg=__import__(dep))
-                      for dep in dependencies)
+    if requires:
+        deps += "\n".join("{pkg.__name__}: {pkg.__version__}".format(pkg=__import__(dep))
+                          for dep in dependencies)
+    else:
+        deps = 'You need Python 3.8 or higher to show dependencies.'

     print(header + general + osinfo + deps)
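
Aside (not part of the diff): the fallback import pattern above, in isolation. On Python 3.8+, importlib.metadata.requires() returns the declared dependency strings; on 3.7 both names fall back to None and sysinfo() degrades gracefully. `list_runtime_deps` is a hypothetical helper, not pyndl API:

try:
    from importlib.metadata import requires
except ModuleNotFoundError:  # Python 3.7 and earlier
    requires = None
try:
    from packaging.requirements import Requirement
except ModuleNotFoundError:
    Requirement = None


def list_runtime_deps(package='pyndl'):
    # Keep only unconditional requirements, i.e. those without an
    # environment marker such as '; python_version < "3.8"'.
    if requires is None or Requirement is None:
        return []
    return [Requirement(req).name for req in requires(package)
            if not Requirement(req).marker]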
pyndl/activation.py (25 changes: 16 additions & 9 deletions)

@@ -9,6 +9,7 @@
 import multiprocessing as mp
 import ctypes
 from collections import defaultdict, OrderedDict
+import warnings

 import numpy as np
 import xarray as xr

@@ -17,7 +18,8 @@


 # pylint: disable=W0621
-def activation(events, weights, number_of_threads=1, remove_duplicates=None, ignore_missing_cues=False):
+def activation(events, weights, *, n_jobs=1, number_of_threads=None,
+               remove_duplicates=None, ignore_missing_cues=False):
     """
     Estimate activations for given events in event file and outcome-cue weights.

@@ -31,7 +33,7 @@
     weights : xarray.DataArray or dict[dict[float]]
         the xarray.DataArray needs to have the dimensions 'outcomes' and 'cues'
         the dictionaries hold weight[outcome][cue].
-    number_of_threads : int
+    n_jobs : int
         a integer giving the number of threads in which the job should
         executed
     remove_duplicates : {None, True, False}

@@ -58,6 +60,11 @@
         returned if weights is instance of dict

     """
+    if number_of_threads is not None:
+        warnings.warn("Parameter `number_of_threads` is renamed to `n_jobs`. The old name "
+                      "will stop working with v0.9.0.",
+                      DeprecationWarning, stacklevel=2)
+        n_jobs = number_of_threads
     if isinstance(events, str):
         events = io.events_from_file(events)

@@ -87,14 +94,14 @@ def check_no_duplicates(cues):
                                   for event_cues in events)
         # pylint: disable=W0621
         activations = _activation_matrix(list(event_cue_indices_list),
-                                         weights.values, number_of_threads)
+                                         weights.values, n_jobs)
         return xr.DataArray(activations,
                             coords={
                                 'outcomes': outcomes
                             },
                             dims=('outcomes', 'events'))
     elif isinstance(weights, dict):
-        assert number_of_threads == 1, "Estimating activations with multiprocessing is not implemented for dicts."
+        assert n_jobs == 1, "Estimating activations with multiprocessing is not implemented for dicts."
         activations = defaultdict(lambda: np.zeros(len(events)))
         events = list(events)
         for outcome, cue_dict in weights.items():

@@ -130,7 +137,7 @@ def _run_mp_activation_matrix(event_index, cue_indices):
     activations[:, event_index] = weights[:, cue_indices].sum(axis=1)


-def _activation_matrix(indices_list, weights, number_of_threads):
+def _activation_matrix(indices_list, weights, n_jobs):
     """
     Estimate activation for indices in weights

@@ -143,18 +150,18 @@
         events as cue indices in weights
     weights : numpy.array
         weight matrix with shape (outcomes, cues)
-    number_of_threads : int
+    n_jobs : int

     Returns
     -------
     activation_matrix : numpy.array
         estimated activations as matrix with shape (outcomes, events)

     """
-    assert number_of_threads >= 1, "Can't run with less than 1 thread"
+    assert n_jobs >= 1, "Can't run with less than 1 thread"

     activations_dim = (weights.shape[0], len(indices_list))
-    if number_of_threads == 1:
+    if n_jobs == 1:
         activations = np.empty(activations_dim, dtype=np.float64)
         for row, event_cues in enumerate(indices_list):
             activations[:, row] = weights[:, event_cues].sum(axis=1)

@@ -164,7 +171,7 @@
         weights = np.ascontiguousarray(weights)
         shared_weights = mp.sharedctypes.copy(np.ctypeslib.as_ctypes(np.float64(weights)))
         initargs = (shared_weights, weights.shape, shared_activations, activations_dim)
-        with mp.Pool(number_of_threads, initializer=_init_mp_activation_matrix, initargs=initargs) as pool:
+        with mp.Pool(n_jobs, initializer=_init_mp_activation_matrix, initargs=initargs) as pool:
             pool.starmap(_run_mp_activation_matrix, enumerate(indices_list))
         activations = np.ctypeslib.as_array(shared_activations)
         activations.shape = activations_dim
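
Aside (not part of the diff): the rename at the call site. Here `weights` is assumed to be an xarray.DataArray from an earlier ndl run, and the event file path is a placeholder:

import warnings

from pyndl.activation import activation

# New, keyword-only spelling:
acts = activation('events.tab.gz', weights, n_jobs=4)

# The old spelling keeps working until v0.9.0, but now warns:
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    acts = activation('events.tab.gz', weights, number_of_threads=4)
assert any(issubclass(w.category, DeprecationWarning) for w in caught)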
pyndl/count.py (41 changes: 28 additions & 13 deletions)

@@ -10,11 +10,16 @@
 """
 # pylint: disable=redefined-outer-name, invalid-name

-from collections import Counter
+from collections import Counter, namedtuple
 import gzip
 import itertools
 import multiprocessing
 import sys
+import warnings


+CuesOutcomes = namedtuple('CuesOutcomes', 'n_events, cues, outcomes')
+WordsSymbols = namedtuple('WordsSymbols', 'words, symbols')


 def _job_cues_outcomes(event_file_name, start, step, verbose=False):

@@ -46,24 +51,29 @@ def _job_cues_outcomes(event_file_name, start, step, verbose=False):


 def cues_outcomes(event_file_name,
-                  *, number_of_processes=2, verbose=False):
+                  *, n_jobs=2, number_of_processes=None, verbose=False):
     """
-    Counts cues and outcomes in event_file_name using number_of_processes
+    Counts cues and outcomes in event_file_name using n_jobs
     processes.

     Returns
     -------
     (n_events, cues, outcomes) : (int, collections.Counter, collections.Counter)

     """
-    with multiprocessing.Pool(number_of_processes) as pool:
-        step = number_of_processes
+    if number_of_processes is not None:
+        warnings.warn("Parameter `number_of_processes` is renamed to `n_jobs`. The old name "
+                      "will stop working with v0.9.0.",
+                      DeprecationWarning, stacklevel=2)
+        n_jobs = number_of_processes
+    with multiprocessing.Pool(n_jobs) as pool:
+        step = n_jobs
         results = pool.starmap(_job_cues_outcomes,
                                ((event_file_name,
                                  start,
                                  step,
                                  verbose)
-                                for start in range(number_of_processes)))
+                                for start in range(n_jobs)))
     n_events = 0
     cues = Counter()
     outcomes = Counter()

@@ -75,7 +85,7 @@ def cues_outcomes(event_file_name,
     if verbose:
         print('\n...counting done.')

-    return n_events, cues, outcomes
+    return CuesOutcomes(n_events, cues, outcomes)


 def _job_words_symbols(corpus_file_name, start, step, lower_case=False,

@@ -117,25 +127,30 @@ def _job_words_symbols(corpus_file_name, start, step, lower_case=False,


 def words_symbols(corpus_file_name,
-                  *, number_of_processes=2, lower_case=False, verbose=False):
+                  *, n_jobs=2, number_of_processes=None, lower_case=False, verbose=False):
     """
-    Counts words and symbols in corpus_file_name using number_of_processes
+    Counts words and symbols in corpus_file_name using n_jobs
     processes.

     Returns
     -------
     (words, symbols) : (collections.Counter, collections.Counter)

     """
-    with multiprocessing.Pool(number_of_processes) as pool:
-        step = number_of_processes
+    if number_of_processes is not None:
+        warnings.warn("Parameter `number_of_processes` is renamed to `n_jobs`. The old name "
+                      "will stop working with v0.9.0.",
+                      DeprecationWarning, stacklevel=2)
+        n_jobs = number_of_processes
+    with multiprocessing.Pool(n_jobs) as pool:
+        step = n_jobs
         results = pool.starmap(_job_words_symbols, ((corpus_file_name,
                                                      start,
                                                      step,
                                                      lower_case,
                                                      verbose)
                                                     for start in
-                                                    range(number_of_processes)))
+                                                    range(n_jobs)))
     words = Counter()
     symbols = Counter()
     for words_process, symbols_process in results:

@@ -145,7 +160,7 @@ def words_symbols(corpus_file_name,
     if verbose:
         print('\n...counting done.')

-    return words, symbols
+    return WordsSymbols(words, symbols)


 def save_counter(counter, filename, *, header='key\tfreq\n'):
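
Aside (not part of the diff): the namedtuple return values give named access while keeping tuple unpacking intact, so existing callers do not break; the file path is a placeholder:

from pyndl import count

counts = count.cues_outcomes('events.tab.gz', n_jobs=2)
print(counts.n_events, len(counts.cues), len(counts.outcomes))

# Old-style unpacking still works, because namedtuples are tuples:
n_events, cues, outcomes = counts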
(Diffs for the remaining 11 changed files are not shown here.)
