Minor decoding fixes (LorenFrankLab#819)
* Fix name and filtering

* Fix linting

* New methods for mua

* Fix name of part table

* Add detection interval and update notebooks

* Add figurl for ripple for debugging

* Update CHANGELOG.md

* Add a way to filter the channels

* Add z-score threshold and raster

* Add demo and paper code references

* Fix ripple trace

* Fix ripple and mua z-score

* Remove raster from mua figurl, update notebooks

* Go back to red color

* Remove because doesn't belong here

* Apply suggestions from code review

Co-authored-by: Chris Brozdowski <[email protected]>

* Handle no zscore case

* Move ripple times inside function

* Simplify getting the time and speed

* Update notebooks

---------

Co-authored-by: Chris Brozdowski <[email protected]>
edeno and CBroz1 authored Feb 9, 2024
1 parent 992de69 commit ffe88c3
Showing 11 changed files with 781 additions and 2,538 deletions.
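
Taken together, the notebook changes below replace a long spike-sorting setup with a much shorter entry point for MUA detection. A minimal sketch of that flow, assembled from the diff hunks later on this page (table names, key fields, and example values are taken from the updated 51_MUA_Detection.py; this is a reading aid, not a verified script):

```python
from spyglass.position import PositionOutput
from spyglass.mua.v1.mua import MuaEventsParameters, MuaEventsV1

nwb_copy_file_name = "mediumnwb20230802_.nwb"  # example file used by the tutorials

# Position merge entry that supplies speed over the detection interval
pos_merge_id = (
    PositionOutput.TrodesPosV1
    & {
        "nwb_file_name": nwb_copy_file_name,
        "interval_list_name": "pos 0 valid times",
        "trodes_pos_params_name": "single_led_upsampled",
    }
).fetch1("merge_id")

MuaEventsParameters().insert_default()  # ensure the "default" parameter set exists

mua_key = {
    "mua_param_name": "default",
    "nwb_file_name": nwb_copy_file_name,
    "sorted_spikes_group_name": "test_group",
    "unit_filter_params_name": "default_exclusion",
    "pos_merge_id": pos_merge_id,
    "detection_interval": "pos 0 valid times",
}
MuaEventsV1().populate(mua_key)

mua_times = (MuaEventsV1 & mua_key).fetch1_dataframe()  # one row per detected event
(MuaEventsV1 & mua_key).create_figurl(zscore_mua=True)  # shareable debugging figure
```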
7 changes: 5 additions & 2 deletions CHANGELOG.md
@@ -28,7 +28,6 @@
 - Spike sorting:
   - Add SpikeSorting V1 pipeline. #651
   - Move modules into spikesorting.v0 #807
-  - Add MUA analysis to spike sorting pipeline
 - LFP:
   - Minor fixes to LFPBandV1 populator and `make`. #706, #795
   - LFPV1: Fix error for multiple lfp settings on same data #775
@@ -41,7 +40,7 @@
   - DLC path handling from config, and normalize naming convention. #722
   - Fix in place column bug #752
 - Decoding:
-  - Add `decoding` pipeline V1. #731, #769
+  - Add `decoding` pipeline V1. #731, #769, #819
   - Add a table to store the decoding results #731
   - Use the new `non_local_detector` package for decoding #731
   - Allow multiple spike waveform features for clusterless decoding #731
@@ -51,6 +50,10 @@
   - Rename SortedSpikesGroup.SortGroup to SortedSpikesGroup.Units #807
   - Change methods with load\_... to fetch\_... for consistency #807
   - Use merge table methods to access part methods #807
+- MUA
+  - Add MUA pipeline V1. #731, #819
+- Ripple
+  - Add figurl to Ripple pipeline #819
 
 ## [0.4.3] (November 7, 2023)
 
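The Ripple entry above has no rendered hunk on this page, so the exact API is not shown. As a hedged sketch only, assuming the ripple pipeline's figurl method mirrors the MuaEventsV1.create_figurl call that appears later in this commit; the import path and key below are assumptions, not taken from this diff:

```python
from spyglass.ripple.v1 import RippleTimesV1  # assumed import path

# hypothetical key identifying one ripple detection result
ripple_key = {
    "nwb_file_name": "mediumnwb20230802_.nwb",
    "ripple_param_name": "default",
}

# assumed to mirror (MuaEventsV1 & mua_key).create_figurl(zscore_mua=True)
(RippleTimesV1 & ripple_key).create_figurl(zscore_ripple=True)
```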
4 changes: 4 additions & 0 deletions README.md
@@ -10,6 +10,8 @@ visualization, and sharing of neuroscience data to support reproducible
 research. It is designed to be interoperable with the NWB format and integrates
 open-source tools into a coherent framework.
 
+Try out a demo [here](https://spyglass.hhmi.2i2c.cloud/hub/user-redirect/git-pull?repo=https%3A%2F%2Fgithub.com%2FLorenFrankLab%2Fspyglass-demo&urlpath=lab%2Ftree%2Fspyglass-demo%2Fnotebooks%2F01_Insert_Data.ipynb&branch=main)!
+
 Features of Spyglass include:
 
 - **Standardized data storage** - Spyglass uses the open-source
@@ -86,3 +88,5 @@ a data analysis framework for reproducible and shareable neuroscience research.
 [10.1101/2024.01.25.577295](https://doi.org/10.1101/2024.01.25.577295).
 
 *\* Equal contribution*
+
+See paper related code [here](https://github.com/LorenFrankLab/spyglass-paper).
2 changes: 1 addition & 1 deletion notebooks/43_Decoding_SortedSpikes.ipynb
@@ -655,7 +655,7 @@
    ],
    "source": [
     "# look at the sorting within the group we just made\n",
-    "SortedSpikesGroup.SortGroup & {\n",
+    "SortedSpikesGroup.Units & {\n",
     "    \"nwb_file_name\": nwb_copy_file_name,\n",
     "    \"sorted_spikes_group_name\": \"test_group\",\n",
     "    \"unit_filter_params_name\": unit_filter_params_name,\n",
2,766 changes: 457 additions & 2,309 deletions notebooks/51_MUA_Detection.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion notebooks/py_scripts/43_Decoding_SortedSpikes.py
@@ -108,7 +108,7 @@
 # -
 
 # look at the sorting within the group we just made
-SortedSpikesGroup.SortGroup & {
+SortedSpikesGroup.Units & {
     "nwb_file_name": nwb_copy_file_name,
     "sorted_spikes_group_name": "test_group",
     "unit_filter_params_name": unit_filter_params_name,
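Both the notebook and its py_script mirror carry the same one-line rename from the #807 refactor. For context, a before/after sketch; the import path comes from the MUA script diff below, and the key values are the tutorials' examples:

```python
from spyglass.spikesorting.analysis.v1.group import SortedSpikesGroup

group_key = {
    "nwb_file_name": "mediumnwb20230802_.nwb",
    "sorted_spikes_group_name": "test_group",
    "unit_filter_params_name": "default_exclusion",
}

# SortedSpikesGroup.SortGroup & group_key  # old part-table name, removed in #807
SortedSpikesGroup.Units & group_key  # current name, used by the updated notebooks
```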
224 changes: 47 additions & 177 deletions notebooks/py_scripts/51_MUA_Detection.py
@@ -13,229 +13,99 @@
 # ---
 
 # +
+from pathlib import Path
+
 import datajoint as dj
-from pathlib import Path
 
 dj.config.load(
     Path("../dj_local_conf.json").absolute()
 )  # load config for database connection info
 
+from spyglass.mua.v1.mua import MuaEventsV1, MuaEventsParameters
+
 # -
 
 # # MUA Analysis and Detection
-#
-# NOTE: This notebook is a work in progress. It is not yet complete and may contain errors.
+
+MuaEventsParameters()
+
+MuaEventsV1()
 
 # +
-from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput
-import spyglass.spikesorting.v1 as sgs
-
 from spyglass.position import PositionOutput
 
 nwb_copy_file_name = "mediumnwb20230802_.nwb"
 
-sorter_keys = {
+trodes_s_key = {
     "nwb_file_name": nwb_copy_file_name,
-    "sorter": "clusterless_thresholder",
-    "sorter_param_name": "default_clusterless",
+    "interval_list_name": "pos 0 valid times",
+    "trodes_pos_params_name": "single_led_upsampled",
 }
-
-(sgs.SpikeSortingSelection & sorter_keys) * SpikeSortingOutput.CurationV1
-
-# +
-spikesorting_merge_ids = (
-    (sgs.SpikeSortingSelection & sorter_keys) * SpikeSortingOutput.CurationV1
-).fetch("merge_id")
-
-spikesorting_merge_ids
-
-# +
-from spyglass.spikesorting.unit_inclusion_merge import (
-    ImportedUnitInclusionV1,
-    UnitInclusionOutput,
-)
-
-ImportedUnitInclusionV1().insert_all_units(spikesorting_merge_ids)
-
-UnitInclusionOutput.ImportedUnitInclusionV1() & [
-    {"spikesorting_merge_id": id} for id in spikesorting_merge_ids
-]
-
-# +
-from spyglass.spikesorting.unit_inclusion_merge import (
-    ImportedUnitInclusionV1,
-    UnitInclusionOutput,
-)
-
-ImportedUnitInclusionV1().insert_all_units(spikesorting_merge_ids)
-
-UnitInclusionOutput.ImportedUnitInclusionV1() & [
-    {"spikesorting_merge_id": id} for id in spikesorting_merge_ids
-]
+pos_merge_id = (PositionOutput.TrodesPosV1 & trodes_s_key).fetch1("merge_id")
+pos_merge_id
 
 # +
-from spyglass.spikesorting.unit_inclusion_merge import SortedSpikesGroup
-
-unit_inclusion_merge_ids = (
-    UnitInclusionOutput.ImportedUnitInclusionV1
-    & [{"spikesorting_merge_id": id} for id in spikesorting_merge_ids]
-).fetch("merge_id")
-
-SortedSpikesGroup().create_group(
-    group_name="test_group",
-    nwb_file_name=nwb_copy_file_name,
-    unit_inclusion_merge_ids=unit_inclusion_merge_ids,
+from spyglass.spikesorting.analysis.v1.group import (
+    SortedSpikesGroup,
 )
 
-group_key = {
+sorted_spikes_group_key = {
     "nwb_file_name": nwb_copy_file_name,
     "sorted_spikes_group_name": "test_group",
     "unit_filter_params_name": "default_exclusion",
 }
 
-SortedSpikesGroup & group_key
-# -
-
-SortedSpikesGroup.Units() & group_key
-
-# An example of how to get spike times
-
-spike_times = SortedSpikesGroup.fetch_spike_data(group_key)
-spike_times[0]
+SortedSpikesGroup & sorted_spikes_group_key
 
-# +
-from spyglass.position import PositionOutput
-
-position_merge_id = (
-    PositionOutput.TrodesPosV1
-    & {
-        "nwb_file_name": nwb_copy_file_name,
-        "interval_list_name": "pos 0 valid times",
-        "trodes_pos_params_name": "default_decoding",
-    }
-).fetch1("merge_id")
-
-position_info = (
-    (PositionOutput & {"merge_id": position_merge_id})
-    .fetch1_dataframe()
-    .dropna()
-)
-position_info
-
-# +
-time_ind_slice = slice(63_000, 70_000)
-time = position_info.index[time_ind_slice]
-
-SortedSpikesGroup.get_spike_indicator(group_key, time)
-
-# +
-import matplotlib.pyplot as plt
-
-fig, axes = plt.subplots(2, 1, sharex=True, figsize=(15, 4))
-multiunit_firing_rate = SortedSpikesGroup.get_firing_rate(
-    group_key, time, multiunit=True
-)
-axes[0].plot(
-    time,
-    multiunit_firing_rate,
-)
-axes[0].set_ylabel("firing rate (Hz)")
-axes[0].set_title("multiunit")
-axes[1].fill_between(
-    time, position_info["speed"].iloc[time_ind_slice], color="lightgrey"
-)
-axes[1].set_ylabel("speed (cm/s)")
-axes[1].set_xlabel("time (s)")
-
 # +
-from spyglass.mua.v1.mua import MuaEventsParameters, MuaEventsV1
-
 MuaEventsParameters().insert_default()
 MuaEventsParameters()
 
 # +
-selection_key = {
+mua_key = {
     "mua_param_name": "default",
-    "nwb_file_name": nwb_copy_file_name,
-    "sorted_spikes_group_name": "test_group",
-    "pos_merge_id": position_merge_id,
-    "artifact_interval_list_name": "test_artifact_times",
+    **sorted_spikes_group_key,
+    "pos_merge_id": pos_merge_id,
+    "detection_interval": "pos 0 valid times",
 }
 
-MuaEventsV1.populate(selection_key)
+MuaEventsV1().populate(mua_key)
+MuaEventsV1 & mua_key
 # -
 
-MuaEventsV1 & selection_key
-
-mua_times = (MuaEventsV1 & selection_key).fetch1_dataframe()
+mua_times = (MuaEventsV1 & mua_key).fetch1_dataframe()
 mua_times
 
 # +
 import matplotlib.pyplot as plt
 import numpy as np
 
 fig, axes = plt.subplots(2, 1, sharex=True, figsize=(15, 4))
+speed = MuaEventsV1.get_speed(mua_key)
+time = speed.index.to_numpy()
+speed = speed.to_numpy()
+multiunit_firing_rate = MuaEventsV1.get_firing_rate(mua_key, time)
+
+time_slice = slice(
+    np.searchsorted(time, mua_times.loc[10].start_time) - 1_000,
+    np.searchsorted(time, mua_times.loc[10].start_time) + 5_000,
+)
+
 axes[0].plot(
-    time,
-    multiunit_firing_rate,
+    time[time_slice],
+    multiunit_firing_rate[time_slice],
+    color="black",
 )
 axes[0].set_ylabel("firing rate (Hz)")
 axes[0].set_title("multiunit")
-axes[1].fill_between(
-    time, position_info["speed"].iloc[time_ind_slice], color="lightgrey"
-)
+axes[1].fill_between(time[time_slice], speed[time_slice], color="lightgrey")
 axes[1].set_ylabel("speed (cm/s)")
 axes[1].set_xlabel("time (s)")
 
-in_bounds = np.logical_and(
-    mua_times.start_time >= time[0], mua_times.end_time <= time[-1]
-)
-
-for mua_time in mua_times.loc[in_bounds].itertuples():
-    axes[0].axvspan(
-        mua_time.start_time, mua_time.end_time, color="red", alpha=0.3
-    )
-    axes[1].axvspan(
-        mua_time.start_time, mua_time.end_time, color="red", alpha=0.3
-    )
+for id, mua_time in mua_times.loc[
+    np.logical_and(
+        mua_times["start_time"] > time[time_slice].min(),
+        mua_times["end_time"] < time[time_slice].max(),
+    )
+].iterrows():
+    axes[0].axvspan(
+        mua_time["start_time"], mua_time["end_time"], color="red", alpha=0.5
+    )
+axes[1].set_ylim((0, 80))
+axes[1].axhline(4, color="black", linestyle="--")
+axes[1].set_xlim((time[0], time[-1]))
 
-# +
-from spyglass.common import IntervalList
-
-IntervalList() & {
-    "nwb_file_name": nwb_copy_file_name,
-    "pipeline": "spikesorting_artifact_v1",
-}
-# -
-
-(
-    sgs.ArtifactDetectionParameters
-    * sgs.SpikeSortingRecording
-    * sgs.ArtifactDetectionSelection
-)
-
-SpikeSortingOutput.CurationV1() * (
-    sgs.ArtifactDetectionParameters
-    * sgs.SpikeSortingRecording
-    * sgs.ArtifactDetectionSelection
-)
-
-(
-    IntervalList()
-    & {
-        "nwb_file_name": nwb_copy_file_name,
-        "pipeline": "spikesorting_artifact_v1",
-    }
-).proj(artifact_id="interval_list_name")
-
-sgs.SpikeSortingRecording() * sgs.ArtifactDetectionSelection()
-
-SpikeSortingOutput.CurationV1() * sgs.SpikeSortingRecording()
-
-IntervalList.insert1(
-    {
-        "nwb_file_name": nwb_copy_file_name,
-        "interval_list_name": "test_artifact_times",
-        "valid_times": [],
-    }
-)
+(MuaEventsV1 & mua_key).create_figurl(
+    zscore_mua=True,
+)
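
The bullets "Add z-score threshold and raster" and "Fix ripple and mua z-score" refer to detecting events where the z-scored multiunit firing rate crosses a threshold. A minimal illustrative sketch of that idea, assuming a simple mean/std z-score; this is not spyglass's actual implementation, and the function and parameter names below are hypothetical:

```python
import numpy as np

def detect_events_by_zscore(firing_rate, time, zscore_threshold=2.0):
    """Return (start, end) pairs where the z-scored rate exceeds threshold.

    Hypothetical helper for illustration; not a spyglass API.
    """
    zscored = (firing_rate - firing_rate.mean()) / firing_rate.std()
    # pad so events touching the array edges still produce both edges
    above = np.concatenate(([False], zscored > zscore_threshold, [False]))
    edges = np.diff(above.astype(int))
    starts = time[np.nonzero(edges == 1)[0]]  # first sample above threshold
    ends = time[np.nonzero(edges == -1)[0] - 1]  # last sample above threshold
    return list(zip(starts, ends))
```

The dashed line at 4 cm/s that the updated plotting cell draws on the speed axis presumably marks an immobility threshold relevant to event detection, though that is an inference from the plot, not from this diff.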
1 change: 1 addition & 0 deletions src/spyglass/mua/__init__.py
@@ -0,0 +1 @@
+from spyglass.mua.v1.mua import MuaEventsParameters, MuaEventsV1  # noqa: F401
1 change: 1 addition & 0 deletions src/spyglass/mua/v1/__init__.py
@@ -0,0 +1 @@
+from spyglass.mua.v1.mua import MuaEventsParameters, MuaEventsV1  # noqa: F401
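
The two new __init__.py files are identical one-line re-exports, so after this commit both import styles resolve to the same classes:

```python
from spyglass.mua import MuaEventsParameters, MuaEventsV1  # via the new package re-export
from spyglass.mua.v1.mua import MuaEventsParameters, MuaEventsV1  # direct module path
```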