Skip to content

Commit

Permalink
add method for getting the data corresponding to single event
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mayorquin committed Aug 8, 2024
1 parent 8f64a6d commit 0db881e
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 13 deletions.
2 changes: 1 addition & 1 deletion spec/ndx-binned-spikes.extensions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ groups:
doc: The binned data. It should be an array whose first dimension is the number
of units, the second dimension is the total number of events of all stimuli,
and the third dimension is the number of bins.
- name: event_index
- name: event_indices
dtype: int64
dims:
- number_of_events
Expand Down
11 changes: 9 additions & 2 deletions src/pynwb/ndx_binned_spikes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ class AggregatedBinnedAlignedSpikes(NWBDataInterface):
),
},
{
"name": "event_index",
"name": "event_indices",
"type": "array_data",
"doc": "The timestamps at which the events occurred.",
"shape": (None,),
Expand All @@ -181,7 +181,14 @@ def __init__(self, **kwargs):
for key in kwargs:
setattr(self, key, kwargs[key])


# Should this return an instance of BinnedAlignedSpikes or just the data as it is?
# Going with the simple one for the moment
def get_data_for_stimuli(self, event_index):

mask = self.event_indices == event_index
binned_spikes_for_unit = self.data[:, mask, :]

return binned_spikes_for_unit


# Remove these functions from the package
Expand Down
29 changes: 22 additions & 7 deletions src/pynwb/tests/test_aggregated_binned_aligned_spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def setUp(self):
self.milliseconds_from_event_to_first_bin = -100.0

# Two units in total and 4 bins, and event with two timestamps
data_for_first_event_instance = np.array(
self.data_for_first_event_instance = np.array(
[
# Unit 1 data
[
Expand All @@ -36,7 +36,7 @@ def setUp(self):
)

# Also two units and 4 bins but this event appeared three times
data_for_second_event_instance = np.array(
self.data_for_second_event_instance = np.array(
[
# Unit 1 data
[
Expand All @@ -53,14 +53,14 @@ def setUp(self):
]
)

self.event_index = np.concatenate(
self.event_indices = np.concatenate(
[
np.full(instance.shape[1], i)
for i, instance in enumerate([data_for_first_event_instance, data_for_second_event_instance])
for i, instance in enumerate([self.data_for_first_event_instance, self.data_for_second_event_instance])
]
)

self.data = np.concatenate([data_for_first_event_instance, data_for_second_event_instance], axis=1)
self.data = np.concatenate([self.data_for_first_event_instance, self.data_for_second_event_instance], axis=1)


def test_constructor(self):
Expand All @@ -70,11 +70,11 @@ def test_constructor(self):
bin_width_in_milliseconds=self.bin_width_in_milliseconds,
milliseconds_from_event_to_first_bin=self.milliseconds_from_event_to_first_bin,
data=self.data,
event_index=self.event_index,
event_indices=self.event_indices,
)

np.testing.assert_array_equal(aggregated_binnned_align_spikes.data, self.data)
np.testing.assert_array_equal(aggregated_binnned_align_spikes.event_index, self.event_index)
np.testing.assert_array_equal(aggregated_binnned_align_spikes.event_indices, self.event_indices)
self.assertEqual(aggregated_binnned_align_spikes.bin_width_in_milliseconds, self.bin_width_in_milliseconds)
self.assertEqual(
aggregated_binnned_align_spikes.milliseconds_from_event_to_first_bin, self.milliseconds_from_event_to_first_bin
Expand All @@ -83,3 +83,18 @@ def test_constructor(self):
self.assertEqual(aggregated_binnned_align_spikes.data.shape[0], self.number_of_units)
self.assertEqual(aggregated_binnned_align_spikes.data.shape[1], self.number_of_events)
self.assertEqual(aggregated_binnned_align_spikes.data.shape[2], self.number_of_bins)


def test_get_single_event_data_method(self):

aggregated_binnned_align_spikes = AggregatedBinnedAlignedSpikes(
bin_width_in_milliseconds=self.bin_width_in_milliseconds,
milliseconds_from_event_to_first_bin=self.milliseconds_from_event_to_first_bin,
data=self.data,
event_indices=self.event_indices,
)


data_for_stimuli_1 = aggregated_binnned_align_spikes.get_data_for_stimuli(event_index=0)

np.testing.assert_allclose(data_for_stimuli_1, self.data_for_first_event_instance)
6 changes: 3 additions & 3 deletions src/spec/create_extension_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def main():
dims=["num_units", "number_of_events", "number_of_bins"],
)

event_index = NWBDatasetSpec(
name="event_index",
event_indices = NWBDatasetSpec(
name="event_indices",
doc="The index of the event that each row of the data corresponds to.",
dtype="int64",
shape=[None],
Expand All @@ -115,7 +115,7 @@ def main():
neurodata_type_inc="NWBDataInterface",
default_name="AggregatedBinnedAlignedSpikes",
doc="A data interface for aggregated binned spike data aligned to an multiple events (e.g. a stimuli or the beginning of a trial).",
datasets=[aggregated_binned_aligned_spikes_data, event_index, units_region],
datasets=[aggregated_binned_aligned_spikes_data, event_indices, units_region],
attributes=[
NWBAttributeSpec(
name="name",
Expand Down

0 comments on commit 0db881e

Please sign in to comment.