Skip to content

Commit

Permalink
Prevent error from unitless spike group (LorenFrankLab#1083)
Browse files Browse the repository at this point in the history
* prevent error from unitless spike group

* fix 1077

* update changelog
  • Loading branch information
samuelbray32 authored Aug 29, 2024
1 parent adfed75 commit d4dbc23
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@

- Disable populate transaction protection for long-populating tables #1066

### Pipelines

- Decoding
- Fix edge case errors in spike time loading #1083

## [0.5.3] (August 27, 2024)

### Infrastructure
Expand Down
11 changes: 8 additions & 3 deletions src/spyglass/decoding/v1/waveform_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,12 @@ def make(self, key):
sorter,
)

spike_times = SpikeSortingOutput().fetch_nwb(merge_key)[0][
analysis_nwb_key
]["spike_times"]
nwb = SpikeSortingOutput().fetch_nwb(merge_key)[0]
spike_times = (
nwb[analysis_nwb_key]["spike_times"]
if analysis_nwb_key in nwb
else pd.DataFrame()
)

(
key["analysis_file_name"],
Expand Down Expand Up @@ -349,6 +352,8 @@ def _write_waveform_features_to_nwb(
metric_dict[unit_id] if unit_id in metric_dict else []
for unit_id in unit_ids
]
if not metric_values:
metric_values = np.array([]).astype(np.float32)
nwbf.add_unit_column(
name=metric,
description=metric,
Expand Down
6 changes: 3 additions & 3 deletions src/spyglass/spikesorting/spikesorting_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from ripple_detection import get_multiunit_population_firing_rate

from spyglass.spikesorting.imported import ImportedSpikeSorting # noqa: F401
from spyglass.spikesorting.v0.spikesorting_curation import (
from spyglass.spikesorting.v0.spikesorting_curation import ( # noqa: F401
CuratedSpikeSorting,
) # noqa: F401
)
from spyglass.spikesorting.v1 import ArtifactDetectionSelection # noqa: F401
from spyglass.spikesorting.v1 import (
CurationV1,
Expand Down Expand Up @@ -210,7 +210,7 @@ def get_spike_indicator(cls, key, time):
"""
time = np.asarray(time)
min_time, max_time = time[[0, -1]]
spike_times = cls.fetch_spike_data(key) # CB: This is undefined.
spike_times = (cls & key).get_spike_times(key)
spike_indicator = np.zeros((len(time), len(spike_times)))

for ind, times in enumerate(spike_times):
Expand Down

0 comments on commit d4dbc23

Please sign in to comment.