diff --git a/pyndl/activation.py b/pyndl/activation.py index 4d1f3b6..a7858c9 100644 --- a/pyndl/activation.py +++ b/pyndl/activation.py @@ -40,7 +40,7 @@ def activation(event_list, weights, number_of_threads=1, remove_duplicates=None, Returns ------- activations : xarray.DataArray - with dimensions 'events' and 'outcomes'. Contains coords for the outcomes. + with dimensions 'outcomes' and 'events'. Contains coords for the outcomes. returned if weights is instance of xarray.DataArray or @@ -84,7 +84,7 @@ def enforce_no_duplicates(cues): coords={ 'outcomes': outcomes }, - dims=('events', 'outcomes')) + dims=('outcomes', 'events')) elif isinstance(weights, dict): assert number_of_threads == 1, "Estimating activations with multiprocessing is not implemented for dicts." activations = defaultdict(lambda: np.zeros(len(event_cues_list))) @@ -118,7 +118,7 @@ def _run_mp_activation_matrix(event_index, cue_indices): Calculate activation for all outcomes while a event. """ - activations[event_index, :] = weights[:, cue_indices].sum(axis=1) + activations[:, event_index] = weights[:, cue_indices].sum(axis=1) def _activation_matrix(indices_list, weights, number_of_threads): @@ -139,16 +139,16 @@ def _activation_matrix(indices_list, weights, number_of_threads): Returns ------- activation_matrix : numpy.array - estimated activations as matrix with shape (events, outcomes) + estimated activations as matrix with shape (outcomes, events) """ assert number_of_threads >= 1, "Can't run with less than 1 thread" - activations_dim = (len(indices_list), weights.shape[0]) + activations_dim = (weights.shape[0], len(indices_list)) if number_of_threads == 1: activations = np.empty(activations_dim, dtype=np.float64) for row, event_cues in enumerate(indices_list): - activations[row, :] = weights[:, event_cues].sum(axis=1) + activations[:, row] = weights[:, event_cues].sum(axis=1) return activations else: shared_activations = mp.RawArray(ctypes.c_double, int(np.prod(activations_dim))) diff --git a/tests/test_activation.py b/tests/test_activation.py index 0d96df3..848db0b 100644 --- a/tests/test_activation.py +++ b/tests/test_activation.py @@ -46,7 +46,7 @@ def test_activation_matrix(): (['c1', 'c3'], []), (['c2'], []), (['c1', 'c1'], [])] - reference_activations = np.array([[1, 1], [0, 1], [1, 0], [0, 1]]) + reference_activations = np.array([[1, 1], [0, 1], [1, 0], [0, 1]]).T with pytest.raises(ValueError): activations = activation(events, weights, number_of_threads=1) @@ -70,7 +70,7 @@ def test_ignore_missing_cues(): (['c1', 'c3'], []), (['c2', 'c4'], []), (['c1', 'c1'], [])] - reference_activations = np.array([[1, 1], [0, 1], [1, 0], [0, 1]]) + reference_activations = np.array([[1, 1], [0, 1], [1, 0], [0, 1]]).T with pytest.raises(KeyError): activations = activation(events, weights, number_of_threads=1,