Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
khl02007 committed Jan 13, 2024
2 parents 3522e48 + 4dd1d8a commit b0ddff5
Show file tree
Hide file tree
Showing 12 changed files with 1,066 additions and 438 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
- Clean up following pre-commit checks. #688
- Add Mixin class to centralize `fetch_nwb` functionality. #692, #734
- Refactor restriction use in `delete_downstream_merge` #703
- Add `cautious_delete` to Mixin class, initial implementation. #711
- Add `cautious_delete` to Mixin class, initial implementation. #711, #762
- Add `deprecation_factory` to facilitate table migration. #717
- Add Spyglass logger. #730
- IntervalList: Add secondary key `pipeline` #742
Expand Down
509 changes: 327 additions & 182 deletions notebooks/20_Position_Trodes.ipynb

Large diffs are not rendered by default.

835 changes: 644 additions & 191 deletions notebooks/30_LFP.ipynb

Large diffs are not rendered by default.

37 changes: 20 additions & 17 deletions notebooks/py_scripts/20_Position_Trodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.16.0
# jupytext_version: 1.15.2
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
Expand Down Expand Up @@ -221,7 +221,19 @@

sgp.v1.TrodesPosV1 & trodes_key

# To retrieve the results as a pandas DataFrame with time as the index, we use `TrodesPosV1.fetch1_dataframe`.
# When we populate the `TrodesPosV1` table, we automatically create an entry in the `PositionOutput` merge table.
# Since this table supports position information from multiple methods, it's best practice to access data through here.
#
# We can view the entry in this table:

# +
from spyglass.position import PositionOutput

PositionOutput.TrodesPosV1 & trodes_key
# -

# To retrieve the results as a pandas DataFrame with time as the index, we use `PositionOutput.fetch1_dataframe`.
# When doing so, we need to restrict the merge table by the
#
# This dataframe has the following columns:
#
Expand All @@ -232,12 +244,10 @@
# - `speed`: the magnitude of the change in head position over time in cm/s
#

position_info = (
sgp.v1.TrodesPosV1()
& {
"nwb_file_name": nwb_copy_file_name,
}
).fetch1_dataframe()
# get the merge id corresponding to our inserted trodes_key
merge_key = (PositionOutput.merge_get_part(trodes_key)).fetch1("KEY")
# use this to restrict PositionOutput and fetch the data
position_info = (PositionOutput & merge_key).fetch1_dataframe()
position_info

# `.index` on the pandas dataframe gives us timestamps.
Expand Down Expand Up @@ -340,15 +350,8 @@
)
sgp.v1.TrodesPosV1.populate(trodes_s_up_key)

# +
upsampled_position_info = (
sgp.v1.TrodesPosV1()
& {
"nwb_file_name": nwb_copy_file_name,
"position_info_param_name": trodes_params_up_name,
}
).fetch1_dataframe()

merge_key = (PositionOutput.merge_get_part(trodes_s_up_key)).fetch1("KEY")
upsampled_position_info = (PositionOutput & merge_key).fetch1_dataframe()
upsampled_position_info

# +
Expand Down
9 changes: 5 additions & 4 deletions notebooks/py_scripts/30_LFP.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.16.0
# jupytext_version: 1.15.2
# kernelspec:
# display_name: Python 3.10.5 64-bit
# language: python
Expand Down Expand Up @@ -175,7 +175,8 @@
{
"target_interval_list_name": interval_list_name,
"filter_name": "LFP 0-400 Hz",
"filter_sampling_rate": 30_000,
"filter_sampling_rate": 30_000, # sampling rate of the data (Hz)
"target_sampling_rate": 1_000, # smpling rate of the lfp output (Hz)
}
)

Expand Down Expand Up @@ -271,7 +272,7 @@
lfp_band_key = (
lfp_band.LFPBandSelection
& {
"merge_id": lfp_key["merge_id"],
"lfp_merge_id": lfp_key["merge_id"],
"filter_name": filter_name,
"lfp_band_sampling_rate": lfp_band_sampling_rate,
}
Expand All @@ -284,7 +285,7 @@
lfp_band.LFPBandSelection() & lfp_band_key

lfp_band.LFPBandV1().populate(lfp_band.LFPBandSelection() & lfp_band_key)
lfp_band.LFPBandV1()
lfp_band.LFPBandV1() & lfp_band_key

# ## Plotting
#
Expand Down
2 changes: 1 addition & 1 deletion src/spyglass/common/common_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def get_djuser_name(cls, dj_user) -> str:
if len(query) != 1:
raise ValueError(
f"Could not find name for datajoint user {dj_user}"
+ f"in common.LabMember.LabMemberInfo: {query}"
+ f" in common.LabMember.LabMemberInfo: {query}"
)

return query[0]
Expand Down
5 changes: 4 additions & 1 deletion src/spyglass/lfp/v1/lfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ def make(self, key):
"sampling_rate", "interval_list_name"
)
sampling_rate = int(np.round(sampling_rate))
target_sampling_rate = (LFPSelection & key).fetch1(
"target_sampling_rate"
)

# to get the list of valid times, we need to combine those from the user with those from the
# raw data
Expand All @@ -96,7 +99,7 @@ def make(self, key):
+ f"{MIN_LFP_INTERVAL_DURATION} sec long."
)
# target user-specified sampling rate
decimation = sampling_rate // key["target_sampling_rate"]
decimation = int(sampling_rate // target_sampling_rate)

# get the LFP filter that matches the raw data
filter = (
Expand Down
1 change: 1 addition & 0 deletions src/spyglass/spikesorting/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .curation_figurl import CurationFigurl, CurationFigurlSelection
from .imported import ImportedSpikeSorting
from .sortingview import SortingviewWorkspace, SortingviewWorkspaceSelection
from .spikesorting_artifact import (
ArtifactDetection,
Expand Down
2 changes: 1 addition & 1 deletion src/spyglass/spikesorting/imported.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class ImportedSpikeSorting(SpyglassMixin, dj.Imported):
definition = """
-> Session
---
object_id: varchar(32)
object_id: varchar(40)
"""

def make(self, key):
Expand Down
26 changes: 10 additions & 16 deletions src/spyglass/spikesorting/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,16 @@ class CuratedSpikeSorting(SpyglassMixin, dj.Part): # noqa: F811

def get_recording(cls, key):
"""get the recording associated with a spike sorting output"""
recording_key = cls.merge_restrict(key).proj()
query = (
source_class_dict[
to_camel_case(cls.merge_get_parent(key).table_name)
]
& recording_key
)
return query.get_recording(recording_key)
source_table = source_class_dict[
to_camel_case(cls.merge_get_parent(key).table_name)
]
query = source_table & cls.merge_get_part(key)
return query.get_recording(query.fetch("KEY"))

def get_sorting(cls, key):
"""get the sorting associated with a spike sorting output"""
sorting_key = cls.merge_restrict(key).proj()
query = (
source_class_dict[
to_camel_case(cls.merge_get_parent(key).table_name)
]
& sorting_key
)
return query.get_sorting(sorting_key)
source_table = source_class_dict[
to_camel_case(cls.merge_get_parent(key).table_name)
]
query = source_table & cls.merge_get_part(key)
return query.get_sorting(query.fetch("KEY"))
1 change: 1 addition & 0 deletions src/spyglass/spikesorting/v1/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ..imported import ImportedSpikeSorting
from .artifact import (
ArtifactDetection,
ArtifactDetectionParameters,
Expand Down
75 changes: 51 additions & 24 deletions src/spyglass/utils/dj_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class SpyglassMixin:
_nwb_table_dict = {}
_delete_dependencies = []
_merge_delete_func = None
_session_pk = None
_member_pk = None

# ------------------------------- fetch_nwb -------------------------------

Expand Down Expand Up @@ -103,6 +105,8 @@ def _delete_deps(self) -> list:
from spyglass.common import LabMember, LabTeam, Session # noqa F401

self._delete_dependencies = [LabMember, LabTeam, Session]
self._session_pk = Session.primary_key[0]
self._member_pk = LabMember.primary_key[0]
return self._delete_dependencies

@property
Expand All @@ -119,10 +123,9 @@ def _merge_del_func(self) -> callable:
self._merge_delete_func = delete_downstream_merge
return self._merge_delete_func

def _find_session(
def _find_session_link(
self,
table: dj.user_tables.UserTable,
Session: dj.user_tables.UserTable,
search_limit: int = 2,
) -> dj.expression.QueryExpression:
"""Find Session table associated with table.
Expand All @@ -141,26 +144,47 @@ def _find_session(
datajoint.expression.QueryExpression or None
Join of table link with Session table if found, else None.
"""
Session = self._delete_deps[-1]
# TODO: check search_limit default is enough for any table in spyglass
if self.full_table_name == Session.full_table_name:
# if self is Session, return self
return self

elif (
# if Session is not in ancestors of table, search children
Session.full_table_name not in table.ancestors()
and search_limit > 0 # prevent infinite recursion
):
for child in table.children():
table = self._find_session(child, Session, search_limit - 1)
if self._session_pk in table.primary_key:
# joinable with Session
return table * Session

elif search_limit > 0:
for child in table.children(as_objects=True):
table = self._find_session_link(child, search_limit - 1)
if table: # table is link, will valid join to Session
break
return table

elif search_limit < 1: # if no session ancestor found and limit reached
elif not table or search_limit < 1: # if none found and limit reached
return # Err kept in parent func to centralize permission logic

return table * Session

def _get_exp_summary(self, sess_link: dj.expression.QueryExpression):
"""Get summary of experimenters for session(s), including NULL.
Parameters
----------
sess_link : datajoint.expression.QueryExpression
Join of table link with Session table.
Returns
-------
str
Summary of experimenters for session(s).
"""
Session = self._delete_deps[-1]

format = dj.U(self._session_pk, self._member_pk)
exp_missing = format & (sess_link - Session.Experimenter).proj(
**{self._member_pk: "NULL"}
)
exp_present = (
format & (sess_link * Session.Experimenter - exp_missing).proj()
)
return exp_missing + exp_present

def _check_delete_permission(self) -> None:
"""Check user name against lab team assoc. w/ self * Session.
Expand All @@ -181,32 +205,35 @@ def _check_delete_permission(self) -> None:
if dj_user in LabMember().admin: # bypass permission check for admin
return

sess = self._find_session(self, Session)
if not sess: # Permit delete if not linked to a session
sess_link = self._find_session_link(table=self)
if not sess_link: # Permit delete if not linked to a session
logger.warn(
"Could not find lab team associated with "
+ f"{self.__class__.__name__}."
+ "\nBe careful not to delete others' data."
)
return

experimenters = (sess * Session.Experimenter).fetch("lab_member_name")
if len(experimenters) < len(sess):
# TODO: adjust to check each session individually? Expensive but
# prevents against edge case of one sess with mult and another
# with none
sess_summary = self._get_exp_summary(
sess_link.restrict(self.restriction)
)
experimenters = sess_summary.fetch(self._member_pk)
if None in experimenters:
raise PermissionError(
f"Please ensure all Sessions have an experimenter:\n{sess}"
"Please ensure all Sessions have an experimenter in "
+ f"SessionExperimenter:\n{sess_summary}"
)

user_name = LabMember().get_djuser_name(dj_user)
for experimenter in set(experimenters):
if user_name not in LabTeam().get_team_members(experimenter):
sess_w_exp = sess_summary & {self._member_pk: experimenter}
raise PermissionError(
f"User '{user_name}' is not on a team with '{experimenter}'"
+ ", an experimenter for session(s):\n"
+ f"{sess * Session.Experimenter}"
+ f"{sess_w_exp}"
)
logger.info(f"Queueing delete for session(s):\n{sess_summary}")

# Rename to `delete` when we're ready to use it
# TODO: Intercept datajoint delete confirmation prompt for merge deletes
Expand Down

0 comments on commit b0ddff5

Please sign in to comment.