From a53b50a3632aeccf1323c24e8b88b4f1f8738a17 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 8 Aug 2023 17:45:21 -0500 Subject: [PATCH] Refactor common --- src/spyglass/common/common_behav.py | 92 ++- src/spyglass/common/common_device.py | 401 ++++++----- src/spyglass/common/common_dio.py | 11 +- src/spyglass/common/common_ephys.py | 646 +++++++++--------- src/spyglass/common/common_filter.py | 529 +++++++------- src/spyglass/common/common_interval.py | 270 ++++---- src/spyglass/common/common_lab.py | 27 +- src/spyglass/common/common_nwbfile.py | 259 +++---- src/spyglass/common/common_position.py | 104 +-- src/spyglass/common/common_region.py | 26 +- .../spikesorting/spikesorting_recording.py | 92 +-- src/spyglass/utils/nwb_helper_fn.py | 35 +- 12 files changed, 1319 insertions(+), 1173 deletions(-) diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index e00418b1d..5bf900532 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -7,6 +7,7 @@ import pandas as pd import pynwb +from ..settings import config from ..utils.dj_helper_fn import fetch_nwb from ..utils.nwb_helper_fn import ( get_all_spatial_series, @@ -48,23 +49,24 @@ def insert_from_nwbfile(cls, nwb_file_name): pos_dict = get_all_spatial_series(nwbf, verbose=True) if pos_dict is not None: - for epoch in pos_dict: - pdict = pos_dict[epoch] + for epoch, pdict in pos_dict.items(): pos_interval_list_name = cls.get_pos_interval_name(epoch) # create the interval list and insert it - interval_dict = dict() - interval_dict["nwb_file_name"] = nwb_file_name - interval_dict["interval_list_name"] = pos_interval_list_name - interval_dict["valid_times"] = pdict["valid_times"] + interval_dict = { + "nwb_file_name": nwb_file_name, + "interval_list_name": pos_interval_list_name, + "valid_times": pdict["valid_times"], + } IntervalList().insert1(interval_dict, skip_duplicates=True) # add this interval list to the table - key = dict() - key["nwb_file_name"] = nwb_file_name - key["interval_list_name"] = pos_interval_list_name - key["source"] = "trodes" - key["import_file_name"] = "" + key = { + "nwb_file_name": nwb_file_name, + "interval_list_name": pos_interval_list_name, + "source": "trodes", + "import_file_name": "", + } cls.insert1(key) @staticmethod @@ -95,16 +97,14 @@ def make(self, key): nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name) nwbf = get_nwb_file(nwb_file_abspath) - # TODO refactor this. 
this calculates sampling rate (unused here) and is expensive to do twice
+        pos_interval_name = PositionSource.get_pos_interval_name(
+            key["interval_list_name"]
+        )
         pos_dict = get_all_spatial_series(nwbf)
-        for epoch in pos_dict:
-            if key[
-                "interval_list_name"
-            ] == PositionSource.get_pos_interval_name(epoch):
-                pdict = pos_dict[epoch]
-                key["raw_position_object_id"] = pdict["raw_position_object_id"]
-                self.insert1(key)
-                break
+        if pos_interval_name in pos_dict:
+            pdict = pos_dict[pos_interval_name]
+            key["raw_position_object_id"] = pdict["raw_position_object_id"]
+            self.insert1(key)

     def fetch_nwb(self, *attrs, **kwargs):
         return fetch_nwb(self, (Nwbfile, "nwb_file_abs_path"), *attrs, **kwargs)
@@ -157,12 +157,9 @@ def make(self, key):
             epoch_list = associated_file_obj.task_epochs.split(",")
             # only insert if this is the statescript file
             print(associated_file_obj.description)
-            if (
-                "statescript".upper() in associated_file_obj.description.upper()
-                or "state_script".upper()
-                in associated_file_obj.description.upper()
-                or "state script".upper()
-                in associated_file_obj.description.upper()
+            if any(
+                word.upper() in associated_file_obj.description.upper()
+                for word in ["statescript", "state_script", "state script"]
             ):
                 # find the file associated with this epoch
                 if str(key["epoch"]) in epoch_list:
@@ -205,8 +202,8 @@ def make(self, key):
         if videos is None:
             print(f"No video data interface found in {nwb_file_name}\n")
             return
-        else:
-            videos = videos.time_series
+
+        videos = videos.time_series

         # get the interval for the current TaskEpoch
         interval_list_name = (TaskEpoch() & key).fetch1("interval_list_name")
@@ -261,38 +258,35 @@ def update_entries(cls, restrict={}):
             cls.update1(row=row)

     @classmethod
-    def get_abs_path(cls, key: Dict):
-        """Return the absolute path for a stored video file given a key with the nwb_file_name and epoch number
+    def get_abs_path(cls, key: dict) -> str:
+        """Return the absolute path for a stored video in SPYGLASS_VIDEO_DIR.

-        The SPYGLASS_VIDEO_DIR environment variable must be set.
+        The key must contain the nwb_file_name and epoch number of the
+        video's entry in this table.

         Parameters
         ----------
         key : dict
-            dictionary with nwb_file_name and epoch as keys
+            dict w/nwb_file_name, referring to an entry with video_file_object_id

         Returns
         -------
         nwb_video_file_abspath : str
             The absolute path for the given file name.
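
+        Example (a sketch with a hypothetical key; the video entry must
+        already exist in this table):
+        >>> VideoFile.get_abs_path(
+        ...     {"nwb_file_name": "example20230622_.nwb", "epoch": 1}
+        ... )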
""" - video_dir = pathlib.Path(os.getenv("SPYGLASS_VIDEO_DIR", None)) - assert video_dir is not None, "You must set SPYGLASS_VIDEO_DIR" - if not video_dir.exists(): - raise OSError("SPYGLASS_VIDEO_DIR does not exist") - video_info = (cls & key).fetch1() - nwb_path = Nwbfile.get_abs_path(key["nwb_file_name"]) - nwbf = get_nwb_file(nwb_path) - nwb_video = nwbf.objects[video_info["video_file_object_id"]] - video_filename = nwb_video.name - # see if the file exists and is stored in the base analysis dir - nwb_video_file_abspath = pathlib.Path( - f"{video_dir}/{pathlib.Path(video_filename)}" + video_filename = ( + get_nwb_file(Nwbfile.get_abs_path(key["nwb_file_name"])) + .objects[(cls & key).fetch1("video_file_object_id")] + .name + ) + + nwb_video_file_abspath = ( + pathlib.Path(config["SPYGLASS_VIDEO_DIR"]) / video_filename ) + if nwb_video_file_abspath.exists(): return nwb_video_file_abspath.as_posix() - else: - raise FileNotFoundError( - f"video file with filename: {video_filename} " - f"does not exist in {video_dir}/" - ) + + raise FileNotFoundError( + f"Video file does not exist: {nwb_video_file_abspath}" + ) diff --git a/src/spyglass/common/common_device.py b/src/spyglass/common/common_device.py index 7849e6099..e05be3fcf 100644 --- a/src/spyglass/common/common_device.py +++ b/src/spyglass/common/common_device.py @@ -40,31 +40,35 @@ class DataAcquisitionDevice(dj.Manual): def insert_from_nwbfile(cls, nwbf, config): """Insert data acquisition devices from an NWB file. - Note that this does not link the DataAcquisitionDevices with a Session. For that, - see DataAcquisitionDeviceList. + Note that this does not link the DataAcquisitionDevices with a Session. + For that, see DataAcquisitionDeviceList. Parameters ---------- nwbf : pynwb.NWBFile The source NWB file object. config : dict - Dictionary read from a user-defined YAML file containing values to replace in the NWB file. + Dictionary read from a user-defined YAML file containing values to + replace in the NWB file. """ _, ndx_devices, _ = cls.get_all_device_names(nwbf, config) for device_name in ndx_devices: new_device_dict = dict() - # read device properties into new_device_dict from PyNWB extension device object + # read device properties into new_device_dict from PyNWB extension + # device object nwb_device_obj = ndx_devices[device_name] name = nwb_device_obj.name adc_circuit = nwb_device_obj.adc_circuit - # transform system value. check if value is in DB. if not, prompt user to add an entry or cancel. + # transform system value. check if value is in DB. if not, prompt + # user to add an entry or cancel. system = cls._add_system(nwb_device_obj.system) - # transform amplifier value. check if value is in DB. if not, prompt user to add an entry or cancel. + # transform amplifier value. check if value is in DB. if not, prompt + # user to add an entry or cancel. amplifier = cls._add_amplifier(nwb_device_obj.amplifier) # standardize how Intan is represented in the database @@ -80,29 +84,34 @@ def insert_from_nwbfile(cls, nwbf, config): if ndx_devices: print( - f"Inserted or referenced data acquisition device(s): {ndx_devices.keys()}" + "Inserted or referenced data acquisition device(s): " + + f"{ndx_devices.keys()}" ) else: print("No conforming data acquisition device metadata found.") @classmethod - def get_all_device_names(cls, nwbf, config): - """Get a list of all device names in the NWB file, after appending and overwriting by the config file. 
+    def get_all_device_names(cls, nwbf, config) -> tuple:
+        """
+        Return device names in the NWB file, after appending and overwriting by
+        the config file.

         Parameters
         ----------
         nwbf : pynwb.NWBFile
             The source NWB file object.
         config : dict
-            Dictionary read from a user-defined YAML file containing values to replace in the NWB file.
+            Dictionary read from a user-defined YAML file containing values to
+            replace in the NWB file.

         Returns
         -------
-        device_name_list : list
+        device_name_list : tuple
             List of data acquisition object names found in the NWB file.
         """
-        # make a dict mapping device name to PyNWB device object for all devices in the NWB file that are
-        # of type ndx_franklab_novela.DataAcqDevice and thus have the required metadata
+        # make a dict mapping device name to PyNWB device object for all devices
+        # in the NWB file that are of type ndx_franklab_novela.DataAcqDevice and
+        # thus have the required metadata
         ndx_devices = {
             device_obj.name: device_obj
             for device_obj in nwbf.devices.values()
@@ -124,10 +133,11 @@

     @classmethod
     def _add_device(cls, new_device_dict):
-        """Check that the information in the NWB file and the database for the given device name match perfectly.
+        """Ensure match between NWB file info & database entry.

-        If no DataAcquisitionDevice with the given name exists in the database, check whether the user wants to add
-        a new entry instead of referencing an existing entry. If so, return. If not, raise an exception.
+        If no DataAcquisitionDevice with the given name exists in the database,
+        check whether the user wants to add a new entry instead of referencing
+        an existing entry. If so, return. If not, raise an exception.

         Parameters
         ----------
@@ -137,48 +147,48 @@

         Raises
         ------
         PopulateException
-            If user chooses not to add a device to the database when prompted or if the device properties from the
-            NWB file do not match the properties of the corresponding database entry.
+            If user chooses not to add a device to the database when prompted or
+            if the device properties from the NWB file do not match the
+            properties of the corresponding database entry.
         """
         name = new_device_dict["data_acquisition_device_name"]
         all_values = DataAcquisitionDevice.fetch(
             "data_acquisition_device_name"
         ).tolist()
         if name not in all_values:
-            # no entry with the same name exists, prompt user about adding a new entry
+            # no entry with the same name exists, prompt user to add a new entry
             print(
-                f"\nData acquisition device '{name}' was not found in the database. "
-                f"The current values are: {all_values}. "
-                "Please ensure that the device you want to add does not already "
-                "exist in the database under a different name or spelling. "
+                f"\nData acquisition device '{name}' was not found in the "
+                f"database. The current values are: {all_values}. "
+                "Please ensure that the device you want to add does not already"
+                " exist in the database under a different name or spelling. "
                 "If you want to use an existing device in the database, "
-                "please change the corresponding Device object in the NWB file. "
-                "Entering 'N' will raise an exception."
-            )
-            val = input(
-                f"Do you want to add data acquisition device '{name}' to the database? (y/N)"
+                "please change the corresponding Device object in the NWB file."
+                " Entering 'N' will raise an exception."
             )
+            to_db = " to the database"
+            val = input(f"Add data acquisition device '{name}'{to_db}? 
(y/N)") if val.lower() in ["y", "yes"]: cls.insert1(new_device_dict, skip_duplicates=True) return raise PopulateException( - f"User chose not to add data acquisition device '{name}' to the database." + f"User chose not to add device '{name}'{to_db}." ) - # effectively else (entry exists) - # check whether the values provided match the values stored in the database + # Check if values provided match the values stored in the database db_dict = ( DataAcquisitionDevice & {"data_acquisition_device_name": name} ).fetch1() if db_dict != new_device_dict: raise PopulateException( - f"Data acquisition device properties of PyNWB Device object with name '{name}': " - f"{new_device_dict} do not match properties of the corresponding database entry: {db_dict}." + "Data acquisition device properties of PyNWB Device object " + + f"with name '{name}': {new_device_dict} do not match " + f"properties of the corresponding database entry: {db_dict}." ) @classmethod def _add_system(cls, system): - """Check the system value. If it is not in the database, prompt the user to add the value to the database. + """Check the system value. If not in the db, prompt user to add it. This method also renames the system value "MCU" to "SpikeGadgets". @@ -190,7 +200,8 @@ def _add_system(cls, system): Raises ------ PopulateException - If user chooses not to add a device system value to the database when prompted. + If user chooses not to add a device system value to the database + when prompted. Returns ------- @@ -205,29 +216,31 @@ def _add_system(cls, system): ).tolist() if system not in all_values: print( - f"\nData acquisition device system '{system}' was not found in the database. " - f"The current values are: {all_values}. " - "Please ensure that the system you want to add does not already " - "exist in the database under a different name or spelling. " + f"\nData acquisition device system '{system}' was not found in" + f" the database. The current values are: {all_values}. " + "Please ensure that the system you want to add does not already" + " exist in the database under a different name or spelling. " "If you want to use an existing system in the database, " - "please change the corresponding Device object in the NWB file. " - "Entering 'N' will raise an exception." + "please change the corresponding Device object in the NWB file." + " Entering 'N' will raise an exception." ) val = input( - f"Do you want to add data acquisition device system '{system}' to the database? (y/N)" + f"Do you want to add data acquisition device system '{system}'" + + " to the database? (y/N)" ) if val.lower() in ["y", "yes"]: key = {"data_acquisition_device_system": system} DataAcquisitionDeviceSystem.insert1(key, skip_duplicates=True) else: raise PopulateException( - f"User chose not to add data acquisition device system '{system}' to the database." + "User chose not to add data acquisition device system " + + f"'{system}' to the database." ) return system @classmethod def _add_amplifier(cls, amplifier): - """Check the amplifier value. If it is not in the database, prompt the user to add the value to the database. + """Check amplifier value. If not in db, prompt user to add. Parameters ---------- @@ -237,7 +250,8 @@ def _add_amplifier(cls, amplifier): Raises ------ PopulateException - If user chooses not to add a device amplifier value to the database when prompted. + If user chooses not to add a device amplifier value to the database + when prompted. 
Returns ------- @@ -253,16 +267,17 @@ def _add_amplifier(cls, amplifier): ).tolist() if amplifier not in all_values: print( - f"\nData acquisition device amplifier '{amplifier}' was not found in the database. " - f"The current values are: {all_values}. " - "Please ensure that the amplifier you want to add does not already " - "exist in the database under a different name or spelling. " - "If you want to use an existing name in the database, " - "please change the corresponding Device object in the NWB file. " - "Entering 'N' will raise an exception." + f"\nData acquisition device amplifier '{amplifier}' was not " + f"found in the database. The current values are: {all_values}. " + "Please ensure that the amplifier you want to add does not " + "already exist in the database under a different name or " + "spelling. If you want to use an existing name in the database," + " please change the corresponding Device object in the NWB " + "file. Entering 'N' will raise an exception." ) val = input( - f"Do you want to add data acquisition device amplifier '{amplifier}' to the database? (y/N)" + "Do you want to add data acquisition device amplifier " + + f"'{amplifier}' to the database? (y/N)" ) if val.lower() in ["y", "yes"]: key = {"data_acquisition_device_amplifier": amplifier} @@ -271,7 +286,8 @@ def _add_amplifier(cls, amplifier): ) else: raise PopulateException( - f"User chose not to add data acquisition device amplifier '{amplifier}' to the database." + "User chose not to add data acquisition device amplifier " + + f"'{amplifier}' to the database." ) return amplifier @@ -306,14 +322,17 @@ def insert_from_nwbfile(cls, nwbf): for device in nwbf.devices.values(): if isinstance(device, ndx_franklab_novela.CameraDevice): device_dict = dict() - # TODO ideally the ID is not encoded in the name formatted in a particular way - # device.name must have the form "[any string without a space, usually camera] [int]" - device_dict["camera_id"] = int(str.split(device.name)[1]) - device_dict["camera_name"] = device.camera_name - device_dict["manufacturer"] = device.manufacturer - device_dict["model"] = device.model - device_dict["lens"] = device.lens - device_dict["meters_per_pixel"] = device.meters_per_pixel + # TODO ideally the ID is not encoded in the name formatted in a + # particular way device.name must have the form "[any string + # without a space, usually camera] [int]" + device_dict = { + "camera_id": int(device.name.split()[1]), + "camera_name": device.camera_name, + "manufacturer": device.manufacturer, + "model": device.model, + "lens": device.lens, + "meters_per_pixel": device.meters_per_pixel, + } cls.insert1(device_dict, skip_duplicates=True) device_name_list.append(device_dict["camera_name"]) if device_name_list: @@ -326,48 +345,50 @@ def insert_from_nwbfile(cls, nwbf): @schema class ProbeType(dj.Manual): definition = """ - # Type/category of probe, e.g., Neuropixels 1.0 or NeuroNexus X-Y-Z, regardless of configuration. - # This is a controlled vocabulary of probe type names. - # This is separated from Probe because probes like the Neuropixels 1.0 can have different dynamic configurations, - # e.g. channel maps. + # Type/category of probe regardless of configuration. Controlled vocabulary + # of probe type names. e.g., Neuropixels 1.0 or NeuroNexus X-Y-Z, etc. + # Separated from Probe because probes like the Neuropixels 1.0 can have + # different dynamic configurations e.g. channel maps. 
+    probe_type: varchar(80)
    ---
-    probe_description: varchar(2000) # description of this probe
-    manufacturer = "": varchar(200) # manufacturer of this probe
-    num_shanks: int # number of shanks on this probe
+    probe_description: varchar(2000) # description of this probe
+    manufacturer = "": varchar(200) # manufacturer of this probe
+    num_shanks: int # number of shanks on this probe
    """


 @schema
 class Probe(dj.Manual):
    definition = """
-    # A configuration of a ProbeType. For most probe types, there is only one configuration, and that configuration
-    # should always be used. For Neuropixels probes, the specific channel map (which electrodes are used,
-    # where are they, and in what order) can differ between users and sessions, and each configuration should have a
-    # different ProbeType.
-    probe_id: varchar(80) # a unique ID for this probe and dynamic configuration
+    # A configuration of a ProbeType. For most probe types, there is only one,
+    # which should always be used. For Neuropixels, the channel map (which
+    # electrodes used, where they are, and in what order) can differ between
+    # users and sessions. Each config should have a different ProbeType.
+    probe_id: varchar(80) # a unique ID for this probe & dynamic config
    ---
-    -> ProbeType # the type of probe, selected from a controlled list of probe types
-    -> [nullable] DataAcquisitionDevice # the data acquisition device used with this Probe
-    contact_side_numbering: enum("True", "False") # if True, then electrode contacts are facing you when numbering them
+    -> ProbeType # Type of probe, selected from a controlled list
+    -> [nullable] DataAcquisitionDevice # the data acquisition device used
+    contact_side_numbering: enum("True", "False") # Facing you when numbering
    """

    class Shank(dj.Part):
        definition = """
        -> Probe
-        probe_shank: int # shank number within probe. should be unique within a Probe
+        probe_shank: int # unique shank number within probe.
        """

    class Electrode(dj.Part):
        definition = """
+        # Electrode configuration, with ID, contact size, X/Y/Z coordinates
        -> Probe.Shank
-        probe_electrode: int # electrode ID that is output from the data acquisition system
-        # probe_electrode should be unique within a Probe
+        probe_electrode: int # electrode ID, output from acquisition
+                             # system. Unique within a Probe
        ---
        contact_size = NULL: float # (um) contact size
-        rel_x = NULL: float # (um) x coordinate of the electrode within the probe
-        rel_y = NULL: float # (um) y coordinate of the electrode within the probe
-        rel_z = NULL: float # (um) z coordinate of the electrode within the probe
+        rel_x = NULL: float # (um) x coordinate of electrode
+        rel_y = NULL: float # (um) y coordinate of electrode
+        rel_z = NULL: float # (um) z coordinate of electrode
        """

    @classmethod
@@ -379,7 +400,8 @@ def insert_from_nwbfile(cls, nwbf, config):
        nwbf : pynwb.NWBFile
            The source NWB file object.
        config : dict
-            Dictionary read from a user-defined YAML file containing values to replace in the NWB file.
+            Dictionary read from a user-defined YAML file containing values to
+            replace in the NWB file.
Returns ------- @@ -396,7 +418,8 @@ def insert_from_nwbfile(cls, nwbf, config): num_shanks = 0 if probe_type in ndx_probes: - # read probe properties into new_probe_dict from PyNWB extension probe object + # read probe properties into new_probe_dict from PyNWB extension + # probe object nwb_probe_obj = ndx_probes[probe_type] cls.__read_ndx_probe_data( nwb_probe_obj, @@ -412,13 +435,16 @@ def insert_from_nwbfile(cls, nwbf, config): shank_dict ), "`num_shanks` is not equal to the number of shanks." - # if probe id already exists, do not overwrite anything or create new Shanks and Electrodes - # TODO test whether the Shanks and Electrodes in the NWB file match the ones in the database + # if probe id already exists, do not overwrite anything or create + # new Shanks and Electrodes + # TODO: test whether the Shanks and Electrodes in the NWB file match + # the ones in the database query = Probe & {"probe_id": new_probe_dict["probe_id"]} if len(query) > 0: print( - f"Probe ID '{new_probe_dict['probe_id']}' already exists in the database. Spyglass will use " - "that and not create a new Probe, Shanks, or Electrodes." + f"Probe ID '{new_probe_dict['probe_id']}' already exists in" + " the database. Spyglass will use that and not create a new" + " Probe, Shanks, or Electrodes." ) continue @@ -438,14 +464,17 @@ def insert_from_nwbfile(cls, nwbf, config): @classmethod def get_all_probe_names(cls, nwbf, config): - """Get a list of all device names in the NWB file, after appending and overwriting by the config file. + """Get a list of all device names in the NWB. + + Includes all devices, after appending/overwriting by the config file. Parameters ---------- nwbf : pynwb.NWBFile The source NWB file object. config : dict - Dictionary read from a user-defined YAML file containing values to replace in the NWB file. + Dictionary read from a user-defined YAML file containing values to + replace in the NWB file. Returns ------- @@ -453,21 +482,22 @@ def get_all_probe_names(cls, nwbf, config): List of data acquisition object names found in the NWB file. 
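
+        Example (sketch; assumes ``nwbf`` is an open pynwb.NWBFile and
+        ``config`` is the loaded YAML dict; the result is the union of
+        probe types found in both sources):
+        >>> probe_types = Probe.get_all_probe_names(nwbf, config)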
""" - # make a dict mapping probe type to PyNWB object for all devices in the NWB file that are - # of type ndx_franklab_novela.Probe and thus have the required metadata + # make a dict mapping probe type to PyNWB object for all devices in the + # NWB file that are of type ndx_franklab_novela.Probe and thus have the + # required metadata ndx_probes = { device_obj.probe_type: device_obj for device_obj in nwbf.devices.values() if isinstance(device_obj, ndx_franklab_novela.Probe) } - # make a dict mapping probe type to dict of device metadata from the config YAML if exists - if "Probe" in config: - config_probes = [ - probe_dict["probe_type"] for probe_dict in config["Probe"] - ] - else: - config_probes = list() + # make a dict mapping probe type to dict of device metadata from the + # config YAML if exists + config_probes = ( + [probe_dict["probe_type"] for probe_dict in config["Probe"]] + if "Probe" in config + else list() + ) # get all the probe types from the NWB file plus the config YAML all_probes_types = set(ndx_probes.keys()).union(set(config_probes)) @@ -484,48 +514,46 @@ def __read_ndx_probe_data( elect_dict: dict, ): # construct dictionary of values to add to ProbeType - new_probe_type_dict["manufacturer"] = ( - getattr(nwb_probe_obj, "manufacturer") or "" + new_probe_type_dict.update( + { + "manufacturer": getattr(nwb_probe_obj, "manufacturer") or "", + "probe_type": nwb_probe_obj.probe_type, + "probe_description": nwb_probe_obj.probe_description, + "num_shanks": len(nwb_probe_obj.shanks), + } ) - new_probe_type_dict["probe_type"] = nwb_probe_obj.probe_type - new_probe_type_dict[ - "probe_description" - ] = nwb_probe_obj.probe_description - new_probe_type_dict["num_shanks"] = len(nwb_probe_obj.shanks) cls._add_probe_type(new_probe_type_dict) - new_probe_dict["probe_id"] = nwb_probe_obj.probe_type - new_probe_dict["probe_type"] = nwb_probe_obj.probe_type - new_probe_dict["contact_side_numbering"] = ( - "True" if nwb_probe_obj.contact_side_numbering else "False" + new_probe_dict.update( + { + "probe_id": nwb_probe_obj.probe_type, + "probe_type": nwb_probe_obj.probe_type, + "contact_side_numbering": "True" + if nwb_probe_obj.contact_side_numbering + else "False", + } ) - # go through the shanks and add each one to the Shank table for shank in nwb_probe_obj.shanks.values(): - shank_dict[shank.name] = dict() - shank_dict[shank.name]["probe_id"] = new_probe_dict["probe_type"] - shank_dict[shank.name]["probe_shank"] = int(shank.name) + shank_dict[shank.name] = { + "probe_id": new_probe_dict["probe_type"], + "probe_shank": int(shank.name), + } # go through the electrodes and add each one to the Electrode table for electrode in shank.shanks_electrodes.values(): - # the next line will need to be fixed if we have different sized contacts on a shank - elect_dict[electrode.name] = dict() - elect_dict[electrode.name]["probe_id"] = new_probe_dict[ - "probe_type" - ] - elect_dict[electrode.name]["probe_shank"] = shank_dict[ - shank.name - ]["probe_shank"] - elect_dict[electrode.name][ - "contact_size" - ] = nwb_probe_obj.contact_size - elect_dict[electrode.name]["probe_electrode"] = int( - electrode.name - ) - elect_dict[electrode.name]["rel_x"] = electrode.rel_x - elect_dict[electrode.name]["rel_y"] = electrode.rel_y - elect_dict[electrode.name]["rel_z"] = electrode.rel_z + # the next line will need to be fixed if we have different sized + # contacts on a shank + elect_dict[electrode.name] = { + "probe_id": new_probe_dict["probe_type"], + "probe_shank": shank_dict[shank.name]["probe_shank"], + 
"contact_size": nwb_probe_obj.contact_size, + "probe_electrode": int(electrode.name), + "rel_x": electrode.rel_x, + "rel_y": electrode.rel_y, + "rel_z": electrode.rel_z, + } @classmethod def _add_probe_type(cls, new_probe_type_dict): @@ -539,7 +567,8 @@ def _add_probe_type(cls, new_probe_type_dict): Raises ------ PopulateException - If user chooses not to add a probe type to the database when prompted. + If user chooses not to add a probe type to the database when + prompted. Returns ------- @@ -552,29 +581,32 @@ def _add_probe_type(cls, new_probe_type_dict): print( f"\nProbe type '{probe_type}' was not found in the database. " f"The current values are: {all_values}. " - "Please ensure that the probe type you want to add does not already " - "exist in the database under a different name or spelling. " - "If you want to use an existing name in the database, " - "please change the corresponding Probe object in the NWB file. " - "Entering 'N' will raise an exception." + "Please ensure that the probe type you want to add does not " + "already exist in the database under a different name or " + "spelling. If you want to use an existing name in the " + "database, please change the corresponding Probe object in the " + "NWB file. Entering 'N' will raise an exception." ) val = input( - f"Do you want to add probe type '{probe_type}' to the database? (y/N)" + f"Do you want to add probe type '{probe_type}' to the database?" + + " (y/N)" ) if val.lower() in ["y", "yes"]: ProbeType.insert1(new_probe_type_dict, skip_duplicates=True) return raise PopulateException( - f"User chose not to add probe type '{probe_type}' to the database." + f"User chose not to add probe type '{probe_type}' to the " + + "database." ) - # effectively else (entry exists) - # check whether the values provided match the values stored in the database + # else / entry exists: check whether the values provided match the + # values stored in the database db_dict = (ProbeType & {"probe_type": probe_type}).fetch1() if db_dict != new_probe_type_dict: raise PopulateException( - f"\nProbe type properties of PyNWB Probe object with name '{probe_type}': " - f"{new_probe_type_dict} do not match properties of the corresponding database entry: {db_dict}." + "\nProbe type properties of PyNWB Probe object with name " + f"'{probe_type}': {new_probe_type_dict} do not match properties" + f" of the corresponding database entry: {db_dict}." ) return probe_type @@ -587,23 +619,21 @@ def create_from_nwbfile( probe_type: str, contact_side_numbering: bool, ): - """Create a Probe entry and corresponding part table entries using the data in the NWB file. + """Create master/part Probe entry from the NWB file. - This method will parse the electrodes in the electrodes table, electrode groups (as shanks), and devices - (as probes) in the NWB file, but only ones that are associated with the device that matches the given + This method will parse the electrodes in the electrodes table, electrode + groups (as shanks), and devices (as probes) in the NWB file, but only + ones that are associated with the device that matches the given `nwb_device_name`. - Note that this code assumes the relatively standard convention where the NWB device corresponds to a Probe, - the NWB electrode group corresponds to a Shank, and the NWB electrode corresponds to an Electrode. 
+        Note that this code assumes the relatively standard convention where the
+        NWB device corresponds to a Probe, the NWB electrode group corresponds
+        to a Shank, and the NWB electrode corresponds to an Electrode.

-        Example usage:
-        ```
-        sgc.Probe.create_from_nwbfile(
-            nwbfile=nwb_file_name,
-            nwb_device_name="Device",
+        Example usage:
+        ```
+        sgc.Probe.create_from_nwbfile(
+            nwbfile=nwb_file_name, nwb_device_name="Device",
             probe_id="Neuropixels 1.0 Giocomo Lab Configuration",
-            probe_type="Neuropixels 1.0",
-            contact_side_numbering=True
+            probe_type="Neuropixels 1.0", contact_side_numbering=True
         )
         ```

         Parameters
         ----------
         nwb_file_name : str
             The name of the NWB file.
         nwb_device_name : str
-            The name of the PyNWB Device object that represents the probe to read in the NWB file.
+            The name of the PyNWB Device object that represents the probe to
+            read in the NWB file.
         probe_id : str
-            A unique ID for the probe and its configuration, to be used as the primary key for the new Probe entry.
+            A unique ID for the probe and its configuration, to be used as the
+            primary key for the new Probe entry.
         probe_type : str
-            The existing ProbeType entry that represents the type of probe being created. It must exist.
+            The existing ProbeType entry that represents the type of probe being
+            created. It must exist.
         contact_side_numbering : bool
-            Whether the electrode contacts are facing you when numbering them. Stored in the new Probe entry.
+            Whether the electrode contacts are facing you when numbering them.
+            Stored in the new Probe entry.
         """
         from .common_nwbfile import Nwbfile

@@ -633,49 +667,51 @@
             )
             return

-        new_probe_dict = dict()
-        shank_dict = dict()
-        elect_dict = dict()
-
-        new_probe_dict["probe_id"] = probe_id
-        new_probe_dict["probe_type"] = probe_type
-        new_probe_dict["contact_side_numbering"] = (
-            "True" if contact_side_numbering else "False"
-        )
+        new_probe_dict = {
+            "probe_id": probe_id,
+            "probe_type": probe_type,
+            "contact_side_numbering": (
+                "True" if contact_side_numbering else "False"
+            ),
+        }
+        shank_dict = {}
+        elect_dict = {}

         # iterate through the electrodes table in the NWB file
         # and use the group column (ElectrodeGroup) to create shanks
         # and use the device attribute of each ElectrodeGroup to create a probe
-        created_shanks = dict()  # map device name to shank_index (int)
+        created_shanks = {}  # map device name to shank_index (int)
         device_found = False
         for elec_index in range(len(nwbfile.electrodes)):
             electrode_group = nwbfile.electrodes[elec_index, "group"]
             eg_device_name = electrode_group.device.name

-            # only look at electrodes where the associated device is the one specified
+            # only look at electrodes where the associated device is the one
+            # specified
             if eg_device_name == nwb_device_name:
                 device_found = True

-                # if a Shank has not yet been created from the electrode group, then create it
+                # if a Shank has not yet been created from the electrode group,
+                # then create it
                 if electrode_group.name not in created_shanks:
                     shank_index = len(created_shanks)
                     created_shanks[electrode_group.name] = shank_index

                     # build the dictionary of Probe.Shank data
-                    shank_dict[shank_index] = dict()
-                    shank_dict[shank_index]["probe_id"] = new_probe_dict[
-                        "probe_id"
-                    ]
-                    shank_dict[shank_index]["probe_shank"] = shank_index
+                    shank_dict[shank_index] = {
+                        "probe_id": new_probe_dict["probe_id"],
+                        "probe_shank": shank_index,
+                    }

                 # get the probe shank index associated with this Electrode
                 probe_shank = created_shanks[electrode_group.name]

                 # build the 
dictionary of Probe.Electrode data - elect_dict[elec_index] = dict() - elect_dict[elec_index]["probe_id"] = new_probe_dict["probe_id"] - elect_dict[elec_index]["probe_shank"] = probe_shank - elect_dict[elec_index]["probe_electrode"] = elec_index + elect_dict[elec_index] = { + "probe_id": new_probe_dict["probe_id"], + "probe_shank": probe_shank, + "probe_electrode": elec_index, + } if "rel_x" in nwbfile.electrodes[elec_index]: elect_dict[elec_index]["rel_x"] = nwbfile.electrodes[ elec_index, "rel_x" @@ -691,7 +727,8 @@ def create_from_nwbfile( if not device_found: print( - f"No electrodes in the NWB file were associated with a device named '{nwb_device_name}'." + "No electrodes in the NWB file were associated with a device " + + f"named '{nwb_device_name}'." ) return diff --git a/src/spyglass/common/common_dio.py b/src/spyglass/common/common_dio.py index f5ad37de1..e8c12c662 100644 --- a/src/spyglass/common/common_dio.py +++ b/src/spyglass/common/common_dio.py @@ -34,11 +34,12 @@ def make(self, key): ) if behav_events is None: print( - f"No conforming behavioral events data interface found in {nwb_file_name}\n" + "No conforming behavioral events data interface found in " + + f"{nwb_file_name}\n" ) return - # the times for these events correspond to the valid times for the raw data + # Times for these events correspond to the valid times for the raw data key["interval_list_name"] = ( Raw() & {"nwb_file_name": nwb_file_name} ).fetch1("interval_list_name") @@ -55,8 +56,10 @@ def plot_all_dio_events(self): Examples -------- - > (DIOEvents & {'nwb_file_name': 'arthur20220314_.nwb'}).plot_all_dio_events() - > (DIOEvents & [{'nwb_file_name': "arthur20220314_.nwb"}, {"nwb_file_name": "arthur20220316_.nwb"}]).plot_all_dio_events() + > restr1 = {'nwb_file_name': 'arthur20220314_.nwb'} + > restr2 = {'nwb_file_name': 'arthur20220316_.nwb'} + > (DIOEvents & restr1).plot_all_dio_events() + > (DIOEvents & [restr1, restr2]).plot_all_dio_events() """ behavioral_events = self.fetch_nwb() diff --git a/src/spyglass/common/common_ephys.py b/src/spyglass/common/common_ephys.py index 5ca0af03c..81b855fba 100644 --- a/src/spyglass/common/common_ephys.py +++ b/src/spyglass/common/common_ephys.py @@ -49,7 +49,7 @@ def make(self, key): nwbf = get_nwb_file(nwb_file_abspath) for electrode_group in nwbf.electrode_groups.values(): key["electrode_group_name"] = electrode_group.name - # add electrode group location if it does not exist, and fetch the row + # add electrode group location if not exist, and fetch the row key["region_id"] = BrainRegion.fetch_add( region_name=electrode_group.location ) @@ -74,23 +74,25 @@ def make(self, key): @schema class Electrode(dj.Imported): definition = """ + # Electrode configuration, with ID, local and warped X/Y/Z coordinates -> ElectrodeGroup - electrode_id: int # the unique number for this electrode + electrode_id: int # Unique electrode number --- -> [nullable] Probe.Electrode -> BrainRegion name = "": varchar(200) # unique label for each contact - original_reference_electrode = -1: int # the configured reference electrode for this electrode - x = NULL: float # the x coordinate of the electrode position in the brain - y = NULL: float # the y coordinate of the electrode position in the brain - z = NULL: float # the z coordinate of the electrode position in the brain - filtering: varchar(2000) # description of the signal filtering + original_reference_electrode = -1: int # configured reference electrode + x = NULL: float # x coordinate in the brain + y = NULL: float # y coordinate in 
the brain
+    z = NULL: float # z coordinate in the brain
+    filtering: varchar(2000) # description of the signal filtering
     impedance = NULL: float # electrode impedance
-    bad_channel = "False": enum("True", "False") # if electrode is "good" or "bad" as observed during recording
-    x_warped = NULL: float # x coordinate of electrode position warped to common template brain
-    y_warped = NULL: float # y coordinate of electrode position warped to common template brain
-    z_warped = NULL: float # z coordinate of electrode position warped to common template brain
-    contacts: varchar(200) # label of electrode contacts used for a bipolar signal - current workaround
+    bad_channel = "False": enum("True", "False") # as observed during recording
+    x_warped = NULL: float # x coordinate warped to template
+    y_warped = NULL: float # y coordinate warped to template
+    z_warped = NULL: float # z coordinate warped to template
+    contacts: varchar(200) # label of contacts used for bipolar
+                           # signal - current workaround
     """

     def make(self, key):
@@ -99,35 +101,38 @@
         nwbf = get_nwb_file(nwb_file_abspath)
         config = get_config(nwb_file_abspath)

-        if "Electrode" in config:
-            electrode_config_dicts = {
-                electrode_dict["electrode_id"]: electrode_dict
-                for electrode_dict in config["Electrode"]
-            }
-        else:
-            electrode_config_dicts = dict()
+        electrode_config_dicts = {
+            electrode_dict["electrode_id"]: electrode_dict
+            for electrode_dict in config.get("Electrode", [])
+        }

         electrodes = nwbf.electrodes.to_dataframe()
         for elect_id, elect_data in electrodes.iterrows():
-            key["electrode_id"] = elect_id
-            key["name"] = str(elect_id)
-            key["electrode_group_name"] = elect_data.group_name
-            key["region_id"] = BrainRegion.fetch_add(
-                region_name=elect_data.group.location
+            key.update(
+                {
+                    "electrode_id": elect_id,
+                    "name": str(elect_id),
+                    "electrode_group_name": elect_data.group_name,
+                    "region_id": BrainRegion.fetch_add(
+                        region_name=elect_data.group.location
+                    ),
+                    "x": elect_data.x,
+                    "y": elect_data.y,
+                    "z": elect_data.z,
+                    "x_warped": 0,
+                    "y_warped": 0,
+                    "z_warped": 0,
+                    "contacts": "",
+                    "filtering": elect_data.filtering,
+                    "impedance": elect_data.get("imp"),
+                }
             )
-            key["x"] = elect_data.x
-            key["y"] = elect_data.y
-            key["z"] = elect_data.z
-            key["x_warped"] = 0
-            key["y_warped"] = 0
-            key["z_warped"] = 0
-            key["contacts"] = ""
-            key["filtering"] = elect_data.filtering
-            key["impedance"] = elect_data.get("imp")
-
-            # rough check of whether the electrodes table was created by rec_to_nwb and has
-            # the appropriate custom columns used by rec_to_nwb
-            # TODO this could be better resolved by making an extension for the electrodes table
+
+            # rough check of whether the electrodes table was created by
+            # rec_to_nwb and has the appropriate custom columns used by
+            # rec_to_nwb
+            # TODO: this could be better resolved by making an extension for the
+            # electrodes table
             if (
                 isinstance(elect_data.group.device, ndx_franklab_novela.Probe)
                 and "probe_shank" in elect_data
                 and "probe_electrode" in elect_data
                 and "bad_channel" in elect_data
                 and "ref_elect_id" in elect_data
             ):
-                key["probe_id"] = elect_data.group.device.probe_type
-                key["probe_shank"] = elect_data.probe_shank
-                key["probe_electrode"] = elect_data.probe_electrode
-                key["bad_channel"] = (
-                    "True" if elect_data.bad_channel else "False"
+                key.update(
+                    {
+                        "probe_id": elect_data.group.device.probe_type,
+                        "probe_shank": elect_data.probe_shank,
+                        "probe_electrode": elect_data.probe_electrode,
+                        "bad_channel": (
+                            "True" if elect_data.bad_channel 
else "False" + ), + "original_reference_electrode": elect_data.ref_elect_id, + } ) - key["original_reference_electrode"] = elect_data.ref_elect_id - # override with information from the config YAML based on primary key (electrode id) + # override with information from the config YAML based on primary + # key (electrode id) if elect_id in electrode_config_dicts: # check whether the Probe.Electrode being referenced exists query = Probe.Electrode & electrode_config_dicts[elect_id] if len(query) == 0: warnings.warn( - f"No Probe.Electrode exists that matches the data: {electrode_config_dicts[elect_id]}. " - f"The config YAML for Electrode with electrode_id {elect_id} will be ignored." + f"No Probe.Electrode exists that matches the data: " + f"{electrode_config_dicts[elect_id]}. The config YAML " + f"for Electrode with electrode_id {elect_id} will " + "be ignored." ) else: key.update(electrode_config_dicts[elect_id]) @@ -159,7 +171,7 @@ def make(self, key): @classmethod def create_from_config(cls, nwb_file_name: str): - """Create or update Electrode entries from what is specified in the config YAML file. + """Create/update Electrode entries from config YAML file. Parameters ---------- @@ -172,7 +184,7 @@ def create_from_config(cls, nwb_file_name: str): if "Electrode" not in config: return - # map electrode id to dictionary of electrode information from config YAML + # map electrode id to dictionary of electrode info from config YAML electrode_dicts = { electrode_dict["electrode_id"]: electrode_dict for electrode_dict in config["Electrode"] @@ -181,25 +193,26 @@ def create_from_config(cls, nwb_file_name: str): electrodes = nwbf.electrodes.to_dataframe() for nwbfile_elect_id, elect_data in electrodes.iterrows(): if nwbfile_elect_id in electrode_dicts: - # use the information in the electrodes table to start and then add (or overwrite) values from the - # config YAML - key = dict() - key["nwb_file_name"] = nwb_file_name - key["name"] = str(nwbfile_elect_id) - key["electrode_group_name"] = elect_data.group_name - key["region_id"] = BrainRegion.fetch_add( - region_name=elect_data.group.location - ) - key["x"] = elect_data.x - key["y"] = elect_data.y - key["z"] = elect_data.z - key["x_warped"] = 0 - key["y_warped"] = 0 - key["z_warped"] = 0 - key["contacts"] = "" - key["filtering"] = elect_data.filtering - key["impedance"] = elect_data.get("imp") - key.update(electrode_dicts[nwbfile_elect_id]) + # use the information in the electrodes table to start and then + # add (or overwrite) values from the config YAML + key = { + "nwb_file_name": nwb_file_name, + "name": str(nwbfile_elect_id), + "electrode_group_name": elect_data.group_name, + "region_id": BrainRegion.fetch_add( + region_name=elect_data.group.location + ), + "x": elect_data.x, + "y": elect_data.y, + "z": elect_data.z, + "x_warped": 0, + "y_warped": 0, + "z_warped": 0, + "contacts": "", + "filtering": elect_data.filtering, + "impedance": elect_data.get("imp"), + **electrode_dicts[nwbfile_elect_id], + } query = Electrode & {"electrode_id": nwbfile_elect_id} if len(query): cls.update1(key) @@ -211,8 +224,8 @@ def create_from_config(cls, nwb_file_name: str): print(f"Inserted Electrode with ID {nwbfile_elect_id}.") else: warnings.warn( - f"Electrode ID {nwbfile_elect_id} exists in the NWB file but has no corresponding " - "config YAML entry." + f"Electrode ID {nwbfile_elect_id} exists in the NWB file " + "but has no corresponding config YAML entry." 
)


@@ -223,8 +236,8 @@ class Raw(dj.Imported):
     -> Session
     ---
     -> IntervalList
-    raw_object_id: varchar(40) # the NWB object ID for loading this object from the file
-    sampling_rate: float # Sampling rate calculated from data, in Hz
+    raw_object_id: varchar(40) # NWB obj ID for loading this obj from the file
+    sampling_rate: float # Sampling rate calculated from data, in Hz
     comments: varchar(2000)
     description: varchar(2000)
     """

@@ -233,7 +246,8 @@ def make(self, key):
         nwb_file_name = key["nwb_file_name"]
         nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name)
         nwbf = get_nwb_file(nwb_file_abspath)
-        raw_interval_name = "raw data valid times"
+        INTERVAL_LIST_NAME = "raw data valid times"
+
         # get the acquisition object
         try:
             # TODO this assumes there is a single item in NWBFile.acquisition
@@ -245,21 +259,22 @@
                 + f"Skipping entry in {self.full_table_name}"
             )
             return
-        if rawdata.rate is not None:
-            sampling_rate = rawdata.rate
-        else:
-            print("Estimating sampling rate...")
-            # NOTE: Only use first 1e6 timepoints to save time
-            sampling_rate = estimate_sampling_rate(
+
+        sampling_rate = (
+            rawdata.rate
+            if rawdata.rate is not None
+            else estimate_sampling_rate(
                 np.asarray(rawdata.timestamps[: int(1e6)]), 1.5
             )
-            print(f"Estimated sampling rate: {sampling_rate}")
+        )
+        if rawdata.rate is None:
+            print(f"Estimated sampling rate: {sampling_rate}")
         key["sampling_rate"] = sampling_rate
+        interval_dict = {
+            "nwb_file_name": key["nwb_file_name"],
+            "interval_list_name": INTERVAL_LIST_NAME,
+        }

-        interval_dict = dict()
-        interval_dict["nwb_file_name"] = key["nwb_file_name"]
-        interval_dict["interval_list_name"] = raw_interval_name
         if rawdata.rate is not None:
             interval_dict["valid_times"] = np.array(
                 [[0, len(rawdata.data) / rawdata.rate]]
             )
@@ -273,25 +288,33 @@
             )
         IntervalList().insert1(interval_dict, skip_duplicates=True)

-        # now insert each of the electrodes as an individual row, but with the same nwb_object_id
-        key["raw_object_id"] = rawdata.object_id
-        key["sampling_rate"] = sampling_rate
-        print(f'Importing raw data: Sampling rate:\t{key["sampling_rate"]} Hz')
+        # now insert each of the electrodes as an individual row, but with the
+        # same nwb_object_id
+
+        key.update(
+            {
+                "raw_object_id": rawdata.object_id,
+                "sampling_rate": sampling_rate,
+                "interval_list_name": INTERVAL_LIST_NAME,
+                "comments": rawdata.comments,
+                "description": rawdata.description,
+            }
+        )
+        print(
+            f'Importing raw data: Sampling rate:\t{key["sampling_rate"]} Hz\n'
             f'Number of valid intervals:\t{len(interval_dict["valid_times"])}'
         )
-        key["interval_list_name"] = raw_interval_name
-        key["comments"] = rawdata.comments
-        key["description"] = rawdata.description
+
         self.insert1(key, skip_duplicates=True)

     def nwb_object(self, key):
-        # TODO return the nwb_object; FIX: this should be replaced with a fetch call. Note that we're using the raw file
-        # so we can modify the other one.
+        # TODO return the nwb_object; FIX: this should be replaced with a fetch
+        # call. Note that we're using the raw file so we can modify the other
+        # one.
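+        # A fetch-based sketch of what that replacement might look like
+        # (assumes the fetch_nwb helper used elsewhere in this module, which
+        # exposes raw_object_id under a "raw" key):
+        #     return (self & key).fetch_nwb()[0]["raw"]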
nwb_file_name = key["nwb_file_name"] - nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name) - nwbf = get_nwb_file(nwb_file_abspath) - raw_object_id = (self & {"nwb_file_name": key["nwb_file_name"]}).fetch1( + nwbf = get_nwb_file(Nwbfile.get_abs_path(nwb_file_name)) + raw_object_id = (self & {"nwb_file_name": nwb_file_name}).fetch1( "raw_object_id" ) return nwbf.objects[raw_object_id] @@ -306,7 +329,7 @@ class SampleCount(dj.Imported): # Sample count :s timestamp timeseries -> Session --- - sample_count_object_id: varchar(40) # the NWB object ID for loading this object from the file + sample_count_object_id: varchar(40) # NWB obj ID for loading from the file """ def make(self, key): @@ -318,7 +341,8 @@ def make(self, key): sample_count = get_data_interface(nwbf, "sample_count") if sample_count is None: print( - f'Unable to import SampleCount: no data interface named "sample_count" found in {nwb_file_name}.' + "Unable to import SampleCount: no data interface named " + + f'"sample_count" found in {nwb_file_name}.' ) return key["sample_count_object_id"] = sample_count.object_id @@ -341,7 +365,7 @@ class LFPElectrode(dj.Part): """ def set_lfp_electrodes(self, nwb_file_name, electrode_list): - """Removes all electrodes for the specified nwb file and then adds back the electrodes in the list + """Remove all electrodes from NWB file, add back in the list Parameters ---------- @@ -351,29 +375,22 @@ def set_lfp_electrodes(self, nwb_file_name, electrode_list): list of electrodes to be used for LFP """ + this_nwb = {"nwb_file_name": nwb_file_name} + this_selection = LFPSelection() & this_nwb + # remove the session and then recreate the session and Electrode list - (LFPSelection() & {"nwb_file_name": nwb_file_name}).delete() + this_selection.delete() + # check to see if the user allowed the deletion - if ( - len((LFPSelection() & {"nwb_file_name": nwb_file_name}).fetch()) - == 0 - ): - LFPSelection().insert1({"nwb_file_name": nwb_file_name}) - - # TODO: do this in a better way - all_electrodes = ( - Electrode() & {"nwb_file_name": nwb_file_name} - ).fetch(as_dict=True) - primary_key = Electrode.primary_key - for e in all_electrodes: - # create a dictionary so we can insert new elects - if e["electrode_id"] in electrode_list: - lfpelectdict = { - k: v for k, v in e.items() if k in primary_key - } - LFPSelection().LFPElectrode.insert1( - lfpelectdict, replace=True - ) + if not this_selection.fetch(): + electrodes = [ + e + for e in (Electrode & this_nwb).fetch("KEY", as_dict=True) + if e.get("electrode_id") in electrode_list + ] + + LFPSelection().insert1(this_nwb) + LFPSelection().LFPElectrode.insert(electrodes, replace=True) @schema @@ -384,14 +401,18 @@ class LFP(dj.Imported): -> IntervalList # the valid intervals for the data -> FirFilterParameters # the filter used for the data -> AnalysisNwbfile # the name of the nwb file with the lfp data - lfp_object_id: varchar(40) # the NWB object ID for loading this object from the file + lfp_object_id: varchar(40) # NWB obj ID for loading this obj from the file lfp_sampling_rate: float # the sampling rate, in HZ """ def make(self, key): - # get the NWB object with the data; FIX: change to fetch with additional infrastructure + # keep only the intervals > 1 second long + MIN_INTERVAL_LENGTH = 1.0 + + # get the NWB object with the data + # TODO: change to fetch with additional infrastructure rawdata = Raw().nwb_object(key) - sampling_rate, interval_list_name = (Raw() & key).fetch1( + sampling_rate, interval_list_name = (Raw & key).fetch1( "sampling_rate", 
"interval_list_name" ) sampling_rate = int(np.round(sampling_rate)) @@ -403,40 +424,45 @@ def make(self, key): "interval_list_name": interval_list_name, } ).fetch1("valid_times") - # keep only the intervals > 1 second long - min_interval_length = 1.0 - valid = [] - for count, interval in enumerate(valid_times): - if interval[1] - interval[0] > min_interval_length: - valid.append(count) - valid_times = valid_times[valid] + + valid_times = [ + interval + for interval in valid_times + if interval[1] - interval[0] > MIN_INTERVAL_LENGTH + ] print( - f"LFP: found {len(valid)} of {count+1} intervals > {min_interval_length} sec long." + f"LFP: found {len(valid_times)} intervals > " + + f"{MIN_INTERVAL_LENGTH} sec long." ) # target 1 KHz sampling rate decimation = sampling_rate // 1000 # get the LFP filter that matches the raw data - filter = ( + lfp_filter = ( FirFilterParameters() & {"filter_name": "LFP 0-400 Hz"} & {"filter_sampling_rate": sampling_rate} - ).fetch(as_dict=True) - - # there should only be one filter that matches, so we take the first of the dictionaries - key["filter_name"] = filter[0]["filter_name"] - key["filter_sampling_rate"] = filter[0]["filter_sampling_rate"] + ).fetch1() - filter_coeff = filter[0]["filter_coeff"] + filter_coeff = lfp_filter["filter_coeff"] if len(filter_coeff) == 0: print( - f"Error in LFP: no filter found with data sampling rate of {sampling_rate}" + f"Error in LFP: no filter found w/sampling rate {sampling_rate}" ) return None + + key.update( + { + "filter_name": lfp_filter["filter_name"], + "filter_sampling_rate": lfp_filter["filter_sampling_rate"], + } + ) + # get the list of selected LFP Channels from LFPElectrode - electrode_keys = (LFPSelection.LFPElectrode & key).fetch("KEY") - electrode_id_list = list(k["electrode_id"] for k in electrode_keys) + electrode_id_list = list( + (LFPSelection.LFPElectrode & key).fetch("electrode_id") + ) electrode_id_list.sort() lfp_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) @@ -454,17 +480,24 @@ def make(self, key): decimation, ) - # now that the LFP is filtered and in the file, add the file to the AnalysisNwbfile table + # now that the LFP is filtered and in the file, add the file to the + # AnalysisNwbfile table AnalysisNwbfile().add(key["nwb_file_name"], lfp_file_name) - key["analysis_file_name"] = lfp_file_name - key["lfp_object_id"] = lfp_object_id - key["lfp_sampling_rate"] = sampling_rate // decimation + # Update the key with the LFP information + key.update( + { + "analysis_file_name": lfp_file_name, + "lfp_object_id": lfp_object_id, + "lfp_sampling_rate": sampling_rate // decimation, + "interval_list_name": "lfp valid times", + } + ) - # finally, we need to censor the valid times to account for the downsampling + # finally, censor the valid times to account for the downsampling. 
+ # Add interval list for the LFP valid times, skipping duplicates lfp_valid_times = interval_list_censor(valid_times, timestamp_interval) - # add an interval list for the LFP valid times, skipping duplicates - key["interval_list_name"] = "lfp valid times" + IntervalList.insert1( { "nwb_file_name": key["nwb_file_name"], @@ -505,130 +538,117 @@ def fetch1_dataframe(self, *attrs, **kwargs): class LFPBandSelection(dj.Manual): definition = """ -> LFP - -> FirFilterParameters # the filter to use for the data - -> IntervalList.proj(target_interval_list_name='interval_list_name') # the original set of times to be filtered + -> FirFilterParameters # the filter to use for the data + -> IntervalList.proj(target_interval_list_name='interval_list_name') + # the original set of times to be filtered lfp_band_sampling_rate: int # the sampling rate for this band --- - min_interval_len = 1: float # the minimum length of a valid interval to filter + min_interval_len = 1: float # minimum length of a valid interval to filter """ class LFPBandElectrode(dj.Part): definition = """ -> LFPBandSelection -> LFPSelection.LFPElectrode # the LFP electrode to be filtered - reference_elect_id = -1: int # the reference electrode to use; -1 for no reference + reference_elect_id = -1: int # reference electrode; -1 for no reference --- """ - def set_lfp_band_electrodes( - self, - nwb_file_name, - electrode_list, - filter_name, - interval_list_name, - reference_electrode_list, - lfp_band_sampling_rate, - ): - """ - Adds an entry for each electrode in the electrode_list with the specified filter, interval_list, and - reference electrode. - Also removes any entries that have the same filter, interval list and reference electrode but are not - in the electrode_list. - :param nwb_file_name: string - the name of the nwb file for the desired session - :param electrode_list: list of LFP electrodes to be filtered - :param filter_name: the name of the filter (from the FirFilterParameters schema) - :param interval_name: the name of the interval list (from the IntervalList schema) - :param reference_electrode_list: A single electrode id corresponding to the reference to use for all + +def set_lfp_band_electrodes( + self, + nwb_file_name: str, + electrode_list: list, + filter_name: str, + interval_list_name: str, + reference_electrode_list: list, + lfp_band_sampling_rate: int, +) -> None: + """ + Add each in electrode_list w/ filter, interval_list, and ref electrode. + + Also removes any entries that have the same filter, interval list, and + reference electrode but are not in the electrode_list. 
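+
+    Example (hypothetical values; the filter and interval list entries must
+    already exist in their tables):
+    >>> LFPBandSelection().set_lfp_band_electrodes(
+    ...     nwb_file_name="example20230622_.nwb",
+    ...     electrode_list=[0, 1],
+    ...     filter_name="Theta 5-11 Hz",
+    ...     interval_list_name="raw data valid times",
+    ...     reference_electrode_list=[-1],
+    ...     lfp_band_sampling_rate=100,
+    ... )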
+ + Parameters + ---------- + nwb_file_name: str + The name of the nwb file for the desired session + electrode_list: List[int] + List of LFP electrodes to be filtered + filter_name: str + The name of the filter (from the FirFilterParameters schema) + interval_list_name: str + The name of the interval list (from the IntervalList schema) + reference_electrode_list: Union[int, List[int]] + A single electrode id corresponding to the reference to use for all electrodes or a list with one element per entry in the electrode_list - :param lfp_band_sampling_rate: The output sampling rate to be used for the filtered data; must be an + lfp_band_sampling_rate: int + The output sampling rate to be used for the filtered data; must be an integer divisor of the LFP sampling rate - :return: none - """ - # Error checks on parameters - # electrode_list - query = LFPSelection().LFPElectrode() & {"nwb_file_name": nwb_file_name} - available_electrodes = query.fetch("electrode_id") - if not np.all(np.isin(electrode_list, available_electrodes)): - raise ValueError( - "All elements in electrode_list must be valid electrode_ids in the LFPSelection table" - ) - # sampling rate - lfp_sampling_rate = (LFP() & {"nwb_file_name": nwb_file_name}).fetch1( - "lfp_sampling_rate" + + Returns + ------- + None + """ + this_file = {"nwb_file_name": nwb_file_name} + + available_electrodes = (LFPSelection().LFPElectrode() & this_file).fetch( + "electrode_id" + ) + if not np.all(np.isin(electrode_list, available_electrodes)): + raise ValueError( + "All elements in electrode_list must be valid electrode_ids in the " + + "LFPSelection table" ) - decimation = lfp_sampling_rate // lfp_band_sampling_rate - if lfp_sampling_rate // decimation != lfp_band_sampling_rate: - raise ValueError( - f"lfp_band_sampling rate {lfp_band_sampling_rate} is not an integer divisor of lfp " - f"samping rate {lfp_sampling_rate}" - ) - # filter - query = FirFilterParameters() & { - "filter_name": filter_name, - "filter_sampling_rate": lfp_sampling_rate, - } - if not query: - raise ValueError( - f"filter {filter_name}, sampling rate {lfp_sampling_rate} is not in the FirFilterParameters table" - ) - # interval_list - query = IntervalList() & { - "nwb_file_name": nwb_file_name, - "interval_name": interval_list_name, - } - if not query: - raise ValueError( - f"interval list {interval_list_name} is not in the IntervalList table; the list must be " - "added before this function is called" - ) - # reference_electrode_list - if len(reference_electrode_list) != 1 and len( - reference_electrode_list - ) != len(electrode_list): - raise ValueError( - "reference_electrode_list must contain either 1 or len(electrode_list) elements" - ) - # add a -1 element to the list to allow for the no reference option - available_electrodes = np.append(available_electrodes, [-1]) - if not np.all(np.isin(reference_electrode_list, available_electrodes)): - raise ValueError( - "All elements in reference_electrode_list must be valid electrode_ids in the LFPSelection " - "table" - ) - # make a list of all the references - ref_list = np.zeros((len(electrode_list),)) - ref_list[:] = reference_electrode_list - - key = dict() - key["nwb_file_name"] = nwb_file_name - key["filter_name"] = filter_name - key["filter_sampling_rate"] = lfp_sampling_rate - key["target_interval_list_name"] = interval_list_name - key["lfp_band_sampling_rate"] = lfp_sampling_rate // decimation - # insert an entry into the main LFPBandSelectionTable - self.insert1(key, skip_duplicates=True) + lfp_sampling_rate = (LFP() 
& this_file).fetch1("lfp_sampling_rate")

-        # get all of the current entries and delete any that are not in the list
-        elect_id, ref_id = (self.LFPBandElectrode() & key).fetch(
-            "electrode_id", "reference_elect_id"
+        decimation = lfp_sampling_rate // lfp_band_sampling_rate
+        if lfp_sampling_rate // decimation != lfp_band_sampling_rate:
+            raise ValueError(
+                f"lfp_band_sampling_rate {lfp_band_sampling_rate} is not an"
+                + f" integer divisor of lfp_sampling_rate {lfp_sampling_rate}"
+            )
+
+        # filter
+        query = FirFilterParameters() & {
+            "filter_name": filter_name,
+            "filter_sampling_rate": lfp_sampling_rate,
+        }
+        if not query:
+            raise ValueError(
+                f"filter {filter_name}, sampling rate {lfp_sampling_rate} is not in"
+                + " the FirFilterParameters table"
+            )
+
+        # interval_list
+        query = IntervalList() & {
+            "nwb_file_name": nwb_file_name,
+            "interval_list_name": interval_list_name,
+        }
+        if not query:
+            raise ValueError(
+                f"interval list {interval_list_name} is not in the IntervalList "
+                + "table; the list must be added before this function is called"
+            )
+
+        # reference_electrode_list
+        if len(reference_electrode_list) != 1 and len(
+            reference_electrode_list
+        ) != len(electrode_list):
+            raise ValueError(
+                "reference_electrode_list must contain either 1 or "
+                + "len(electrode_list) elements"
+            )
+
+        # add a -1 element to the list to allow for the no reference option
+        available_electrodes = np.append(available_electrodes, [-1])
+        if not np.all(np.isin(reference_electrode_list, available_electrodes)):
+            raise ValueError(
+                "All elements in reference_electrode_list must be valid "
+                + "electrode_ids in the LFPSelection table"
             )
-        for e, r in zip(elect_id, ref_id):
-            if not len(np.where((electrode_list == e) & (ref_list == r))[0]):
-                key["electrode_id"] = e
-                key["reference_elect_id"] = r
-                (self.LFPBandElectrode() & key).delete()
-
-        # iterate through all of the new elements and add them
-        for e, r in zip(electrode_list, ref_list):
-            key["electrode_id"] = e
-            query = Electrode & {
-                "nwb_file_name": nwb_file_name,
-                "electrode_id": e,
-            }
-            key["electrode_group_name"] = query.fetch1("electrode_group_name")
-            key["reference_elect_id"] = r
-            self.LFPBandElectrode().insert1(key, skip_duplicates=True)


 @schema
@@ -638,19 +658,20 @@ class LFPBand(dj.Computed):
     ---
     -> AnalysisNwbfile
     -> IntervalList
-    filtered_data_object_id: varchar(40)  # the NWB object ID for loading this object from the file
+    filtered_data_object_id: varchar(40)  # NWB obj ID for loading from file
     """

     def make(self, key):
-        # get the NWB object with the lfp data; FIX: change to fetch with additional infrastructure
-        lfp_object = (
-            LFP() & {"nwb_file_name": key["nwb_file_name"]}
-        ).fetch_nwb()[0]["lfp"]
+        this_file = {"nwb_file_name": key["nwb_file_name"]}
+        this_band = LFPBandSelection() & key
+        this_electrode = LFPBandSelection.LFPBandElectrode & key
+        this_interval = IntervalList() & this_file

+        lfp_object = (LFP() & this_file).fetch_nwb()[0]["lfp"]

         # get the electrodes to be filtered and their references
-        lfp_band_elect_id, lfp_band_ref_id = (
-            LFPBandSelection().LFPBandElectrode() & key
-        ).fetch("electrode_id", "reference_elect_id")
+        lfp_band_elect_id, lfp_band_ref_id = this_electrode.fetch(
+            "electrode_id", "reference_elect_id"
+        )

         # sort the electrodes to make sure they are in ascending order
         lfp_band_elect_id = np.asarray(lfp_band_elect_id)
@@ -659,39 +680,36 @@ def make(self, key):
         lfp_band_elect_id = lfp_band_elect_id[lfp_sort_order]
         lfp_band_ref_id = lfp_band_ref_id[lfp_sort_order]

-        lfp_sampling_rate = (
-            LFP() 
& {"nwb_file_name": key["nwb_file_name"]} - ).fetch1("lfp_sampling_rate") - interval_list_name, lfp_band_sampling_rate = ( - LFPBandSelection() & key - ).fetch1("target_interval_list_name", "lfp_band_sampling_rate") + lfp_sampling_rate = (LFP() & this_file).fetch1("lfp_sampling_rate") + interval_list_name, lfp_band_sampling_rate = this_band.fetch1( + "target_interval_list_name", "lfp_band_sampling_rate" + ) + valid_times = ( - IntervalList() - & { - "nwb_file_name": key["nwb_file_name"], - "interval_list_name": interval_list_name, - } + this_interval & {"interval_list_name": interval_list_name} ).fetch1("valid_times") - # the valid_times for this interval may be slightly beyond the valid times for the lfp itself, - # so we have to intersect the two - lfp_interval_list = ( - LFP() & {"nwb_file_name": key["nwb_file_name"]} - ).fetch1("interval_list_name") + + # Intersect valid_times and LFP valid times lfp_valid_times = ( - IntervalList() + this_interval & { - "nwb_file_name": key["nwb_file_name"], - "interval_list_name": lfp_interval_list, + "interval_list_name": (LFP() & this_file).fetch1( + "interval_list_name" + ), } ).fetch1("valid_times") - min_length = (LFPBandSelection & key).fetch1("min_interval_len") + lfp_band_valid_times = interval_list_intersect( - valid_times, lfp_valid_times, min_length=min_length + valid_times, + lfp_valid_times, + min_length=this_band.fetch1("min_interval_len"), ) - filter_name, filter_sampling_rate, lfp_band_sampling_rate = ( - LFPBandSelection() & key - ).fetch1( + ( + filter_name, + filter_sampling_rate, + lfp_band_sampling_rate, + ) = this_band.fetch1( "filter_name", "filter_sampling_rate", "lfp_band_sampling_rate" ) @@ -699,7 +717,8 @@ def make(self, key): # load in the timestamps timestamps = np.asarray(lfp_object.timestamps) - # get the indices of the first timestamp and the last timestamp that are within the valid times + + # Indices of first/last timestamp within the valid times included_indices = interval_list_contains_ind( lfp_band_valid_times, timestamps ) @@ -732,21 +751,24 @@ def make(self, key): ) # get the LFP filter that matches the raw data - filter = ( + lfp_filter = ( FirFilterParameters() & {"filter_name": filter_name} & {"filter_sampling_rate": filter_sampling_rate} ).fetch(as_dict=True) - if len(filter) == 0: + + if len(lfp_filter) == 0: raise ValueError( - f"Filter {filter_name} and sampling_rate {lfp_band_sampling_rate} does not exit in the " + f"Filter {filter_name} and sampling_rate " + f"{lfp_band_sampling_rate} does not exit in the " "FirFilterParameters table" ) - filter_coeff = filter[0]["filter_coeff"] + filter_coeff = lfp_filter[0]["filter_coeff"] if len(filter_coeff) == 0: print( - f"Error in LFPBand: no filter found with data sampling rate of {lfp_band_sampling_rate}" + "Error in LFPBand: no filter found with data sampling rate of " + f"{lfp_band_sampling_rate}" ) return None @@ -765,20 +787,19 @@ def make(self, key): decimation, ) - # now that the LFP is filtered, we create an electrical series for it and add it to the file + # Create an electrical series for filtered data with pynwb.NWBHDF5IO( path=lfp_band_file_abspath, mode="a", load_namespaces=True ) as io: nwbf = io.read() - # get the indices of the electrodes in the electrode table of the file to get the right values + # get indices of the electrodes in electrode table of file elect_index = get_electrode_indices(nwbf, lfp_band_elect_id) electrode_table_region = nwbf.create_electrode_table_region( elect_index, "filtered electrode table" ) - eseries_name = "filtered 
data" # TODO: use datatype of data es = pynwb.ecephys.ElectricalSeries( - name=eseries_name, + name="filtered data", data=filtered_data, electrodes=electrode_table_region, timestamps=new_timestamps, @@ -787,26 +808,26 @@ def make(self, key): nwbf.add_scratch(es) io.write(nwbf) filtered_data_object_id = es.object_id - # + # add the file to the AnalysisNwbfile table AnalysisNwbfile().add(key["nwb_file_name"], lfp_band_file_name) - key["analysis_file_name"] = lfp_band_file_name - key["filtered_data_object_id"] = filtered_data_object_id - - # finally, we need to censor the valid times to account for the downsampling if this is the first time we've - # downsampled these data - key["interval_list_name"] = ( - interval_list_name - + " lfp band " - + str(lfp_band_sampling_rate) - + "Hz" + key.update( + { + "analysis_file_name": lfp_band_file_name, + "filtered_data_object_id": filtered_data_object_id, + "interval_list_name": ( + interval_list_name + + " lfp band " + + str(lfp_band_sampling_rate) + + "Hz" + ), + } ) + + # Censor the valid times to account for downsampling + # ... if this is the first time we've downsampled these data tmp_valid_times = ( - IntervalList - & { - "nwb_file_name": key["nwb_file_name"], - "interval_list_name": key["interval_list_name"], - } + this_interval & {"interval_list_name": key["interval_list_name"]} ).fetch("valid_times") if len(tmp_valid_times) == 0: lfp_band_valid_times = interval_list_censor( @@ -848,7 +869,8 @@ def fetch1_dataframe(self, *attrs, **kwargs): @schema class ElectrodeBrainRegion(dj.Manual): definition = """ - # Table with brain region of electrodes determined post-experiment e.g. via histological analysis or CT + # Table with brain region of electrodes determined post-experiment + # e.g. via histological analysis or CT -> Electrode --- -> BrainRegion diff --git a/src/spyglass/common/common_filter.py b/src/spyglass/common/common_filter.py index f7f741984..a2550330e 100644 --- a/src/spyglass/common/common_filter.py +++ b/src/spyglass/common/common_filter.py @@ -21,8 +21,9 @@ def _import_ghostipy(): return gsp except ImportError as e: raise ImportError( - "You must install ghostipy to use filtering methods. Please note that to install ghostipy on " - "an Mac M1, you must first install pyfftw from conda-forge." + "You must install ghostipy to use filtering methods. Please note " + "that to install ghostipy on an Mac M1, you must first install " + "pyfftw from conda-forge." 
) from e


@@ -33,201 +34,264 @@ class FirFilterParameters(dj.Manual):
     filter_sampling_rate: int # sampling rate for this filter
     ---
     filter_type: enum("lowpass", "highpass", "bandpass")
-    filter_low_stop = 0: float # lowest frequency for stop band for low frequency side of filter
-    filter_low_pass = 0: float # lowest frequency for pass band of low frequency side of filter
-    filter_high_pass = 0: float # highest frequency for pass band for high frequency side of filter
-    filter_high_stop = 0: float # highest frequency for stop band of high frequency side of filter
+    filter_low_stop = 0: float   # lowest freq for stop band for low filt
+    filter_low_pass = 0: float   # lowest freq for pass band of low filt
+    filter_high_pass = 0: float  # highest freq for pass band for high filt
+    filter_high_stop = 0: float  # highest freq for stop band of high filt
     filter_comments: varchar(2000) # comments about the filter
-    filter_band_edges: blob # numpy array containing the filter bands (redundant with individual parameters)
-    filter_coeff: longblob # numpy array containing the filter coefficients
+    filter_band_edges: blob  # numpy array of filter bands
+                             # redundant with individual parameters
+    filter_coeff: longblob   # numpy array of filter coefficients
     """

-    def add_filter(self, filter_name, fs, filter_type, band_edges, comments=""):
-        gsp = _import_ghostipy()
+    def add_filter(
+        self,
+        filter_name: str,
+        fs: float,
+        filter_type: str,
+        band_edges: list,
+        comments: str = "",
+    ) -> None:
+        """Add filter to the Filter table.
+
+        Parameters
+        ----------
+        filter_name: str
+            The name of the filter.
+        fs: float
+            The filter sampling rate.
+        filter_type: str
+            The type of the filter ('lowpass', 'highpass', or 'bandpass').
+        band_edges: List[float]
+            The band edges for the filter.
+        comments: str, optional
+            Additional comments for the filter. Default "".
+
+        Returns
+        -------
+        None
+            Returns None if there is an error in the filter type or band
+            frequencies.
+        """
+        VALID_FILTERS = {"lowpass": 2, "highpass": 2, "bandpass": 4}
+        FILTER_ERR = "Error in Filter.add_filter: "
+        FILTER_N_ERR = FILTER_ERR + "filter {} requires {} band_frequencies."

-        # add an FIR bandpass filter of the specified type ('lowpass', 'highpass', or 'bandpass').
+        # add an FIR bandpass filter of the specified type.
         # band_edges should be as follows:
-        #   low pass filter: [high_pass high_stop]
-        #   high pass filter: [low stop low pass]
-        #   band pass filter: [low_stop low_pass high_pass high_stop].
+        #   low pass : [high_pass high_stop]
+        #   high pass: [low stop low pass]
+        #   band pass: [low_stop low_pass high_pass high_stop]. 
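+        # As an illustration (hypothetical values), a theta bandpass filter
+        # for 1 kHz data could be added as:
+        #   FirFilterParameters().add_filter(
+        #       "Theta 5-11 Hz", 1000, "bandpass", [4, 5, 11, 12]
+        #   )
+        # i.e., stop below 4 Hz, pass 5-11 Hz, stop above 12 Hz.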
+        if filter_type not in VALID_FILTERS:
             print(
-                "Error in Filter.add_filter: filter type {} is not "
-                "lowpass"
-                ", "
-                "highpass"
-                " or "
-                """bandpass""".format(filter_type)
+                FILTER_ERR
+                + f"{filter_type} is not a valid type: {list(VALID_FILTERS)}"
             )
             return None

-        p = 2  # transition spline will be quadratic
-        if filter_type == "lowpass" or filter_type == "highpass":
-            # check that two frequencies were passed in and that they are in the right order
-            if len(band_edges) != 2:
-                print(
-                    "Error in Filter.add_filter: lowpass and highpass filter requires two band_frequencies"
-                )
-                return None
-            tw = band_edges[1] - band_edges[0]
+        if len(band_edges) != VALID_FILTERS[filter_type]:
+            print(FILTER_N_ERR.format(filter_name, VALID_FILTERS[filter_type]))
+            return None

-        elif filter_type == "bandpass":
-            if len(band_edges) != 4:
-                print(
-                    "Error in Filter.add_filter: bandpass filter requires four band_frequencies."
-                )
-                return None
-            # the transition width is the mean of the widths of left and right transition regions
-            tw = (
+        gsp = _import_ghostipy()
+        TRANS_SPLINE = 2  # transition spline will be quadratic
+
+        if filter_type != "bandpass":
+            transition_width = band_edges[1] - band_edges[0]
+
+        else:
+            # transition width is mean of left and right transition regions
+            transition_width = (
                 (band_edges[1] - band_edges[0])
                 + (band_edges[3] - band_edges[2])
             ) / 2.0

-        else:
-            raise Exception(f"Unexpected filter type: {filter_type}")
-
-        numtaps = gsp.estimate_taps(fs, tw)
-        filterdict = dict()
-        filterdict["filter_name"] = filter_name
-        filterdict["filter_sampling_rate"] = fs
-        filterdict["filter_comments"] = comments
+        numtaps = gsp.estimate_taps(fs, transition_width)
+        filterdict = {
+            "filter_type": filter_type,
+            "filter_name": filter_name,
+            "filter_sampling_rate": fs,
+            "filter_comments": comments,
+            "filter_low_stop": 0,
+            "filter_low_pass": 0,
+            "filter_high_pass": 0,
+            "filter_high_stop": 0,
+            "filter_band_edges": np.asarray(band_edges),
+        }

         # set the desired frequency response
         if filter_type == "lowpass":
             desired = [1, 0]
-            filterdict["filter_low_stop"] = 0
-            filterdict["filter_low_pass"] = 0
-            filterdict["filter_high_pass"] = band_edges[0]
-            filterdict["filter_high_stop"] = band_edges[1]
+            pass_stop_dict = {
+                "filter_high_pass": band_edges[0],
+                "filter_high_stop": band_edges[1],
+            }
         elif filter_type == "highpass":
             desired = [0, 1]
-            filterdict["filter_low_stop"] = band_edges[0]
-            filterdict["filter_low_pass"] = band_edges[1]
-            filterdict["filter_high_pass"] = 0
-            filterdict["filter_high_stop"] = 0
+            pass_stop_dict = {
+                "filter_low_stop": band_edges[0],
+                "filter_low_pass": band_edges[1],
+            }
         else:
             desired = [0, 1, 1, 0]
-            filterdict["filter_low_stop"] = band_edges[0]
-            filterdict["filter_low_pass"] = band_edges[1]
-            filterdict["filter_high_pass"] = band_edges[2]
-            filterdict["filter_high_stop"] = band_edges[3]
-        filterdict["filter_type"] = filter_type
-        filterdict["filter_band_edges"] = np.asarray(band_edges)
+            pass_stop_dict = {
+                "filter_low_stop": band_edges[0],
+                "filter_low_pass": band_edges[1],
+                "filter_high_pass": band_edges[2],
+                "filter_high_stop": band_edges[3],
+            }
+
         # create 1d array for coefficients
-        filterdict["filter_coeff"] = np.array(
-            gsp.firdesign(numtaps, band_edges, desired, fs=fs, p=p), ndmin=1
+        filterdict.update(
+            {
+                **pass_stop_dict,
+                "filter_coeff": np.array(
+                    gsp.firdesign(
+                        numtaps, band_edges, desired, fs=fs, p=TRANS_SPLINE
+                    ),
+                    ndmin=1,
+                ),
+            }
         )
-        # add this filter to the table
+
         self.insert1(filterdict, skip_duplicates=True)

-    def 
plot_magnitude(self, filter_name, fs):
-        filter = (
+    def _filter_restrict(self, filter_name, fs):
+        return (
             self & {"filter_name": filter_name} & {"filter_sampling_rate": fs}
-        ).fetch(as_dict=True)
-        f = filter[0]
+        ).fetch1()
+
+    def plot_magnitude(self, filter_name, fs):
+        filter_dict = self._filter_restrict(filter_name, fs)
         plt.figure()
-        w, h = signal.freqz(filter[0]["filter_coeff"], worN=65536)
+        w, h = signal.freqz(filter_dict["filter_coeff"], worN=65536)
         magnitude = 20 * np.log10(np.abs(h))
         plt.plot(w / np.pi * fs / 2, magnitude)
         plt.xlabel("Frequency (Hz)")
         plt.ylabel("Magnitude")
         plt.title("Frequency Response")
-        plt.xlim(0, np.max(f["filter_coeffand_edges"] * 2))
+        plt.xlim(0, np.max(filter_dict["filter_band_edges"] * 2))
         plt.ylim(np.min(magnitude), -1 * np.min(magnitude) * 0.1)
         plt.grid(True)

     def plot_fir_filter(self, filter_name, fs):
-        filter = (
-            self & {"filter_name": filter_name} & {"filter_sampling_rate": fs}
-        ).fetch(as_dict=True)
-        f = filter[0]
+        filter_dict = self._filter_restrict(filter_name, fs)
         plt.figure()
         plt.clf()
-        plt.plot(f["filter_coeff"], "k")
+        plt.plot(filter_dict["filter_coeff"], "k")
         plt.xlabel("Coefficient")
         plt.ylabel("Magnitude")
         plt.title("Filter Taps")
         plt.grid(True)

     def filter_delay(self, filter_name, fs):
-        # return the filter delay
-        filter = (
-            self & {"filter_name": filter_name} & {"filter_sampling_rate": fs}
-        ).fetch(as_dict=True)
-        return self.calc_filter_delay(filter["filter_coeff"])
+        return self.calc_filter_delay(
+            self._filter_restrict(filter_name, fs)["filter_coeff"]
+        )
+
+    def _time_bound_check(self, start, stop, all_timestamps, nsamples):
+        timestamp_warn = "Interval time warning: "
+        if start < all_timestamps[0]:
+            warnings.warn(
+                timestamp_warn
+                + "start time smaller than first timestamp, "
+                + f"substituting first: {start} < {all_timestamps[0]}"
+            )
+            start = all_timestamps[0]
+
+        if stop > all_timestamps[-1]:
+            warnings.warn(
+                timestamp_warn
+                + "stop time larger than last timestamp, "
+                + f"substituting last: {stop} > {all_timestamps[-1]}"
+            )
+            stop = all_timestamps[-1]
+
+        frm, to = np.searchsorted(all_timestamps, (start, stop))
+        to = min(to, nsamples)
+        return frm, to

     def filter_data_nwb(
         self,
-        analysis_file_abs_path,
-        eseries,
-        filter_coeff,
-        valid_times,
-        electrode_ids,
-        decimation,
+        analysis_file_abs_path: str,
+        eseries: pynwb.ecephys.ElectricalSeries,
+        filter_coeff: np.ndarray,
+        valid_times: np.ndarray,
+        electrode_ids: list,
+        decimation: int,
         description: str = "filtered data",
         type: Union[None, str] = None,
     ):
         """
-        :param analysis_nwb_file_name: str full path to previously created analysis nwb file where filtered data
-            should be stored. 
This also has the name of the original NWB file where the data will be taken from
-        :param eseries: electrical series with data to be filtered
-        :param filter_coeff: numpy array with filter coefficients for FIR filter
-        :param valid_times: 2D numpy array with start and stop times of intervals to be filtered
-        :param electrode_ids: list of electrode_ids to filter
-        :param decimation: int decimation factor
-        :return: The NWB object id of the filtered data (str), list containing first and last timestamp
-
-        This function takes data and timestamps from an NWB electrical series and filters them using the ghostipy
-        package, saving the result as a new electricalseries in the nwb_file_name, which should have previously been
-        created and linked to the original NWB file using common_session.AnalysisNwbfile.create()
+        Filter data from an NWB electrical series using the ghostipy package,
+        and save the result as a new electrical series in the analysis NWB file.
+
+        Parameters
+        ----------
+        analysis_file_abs_path : str
+            Full path to the analysis NWB file.
+        eseries : pynwb.ecephys.ElectricalSeries
+            Electrical series with data to be filtered.
+        filter_coeff : np.ndarray
+            Array with filter coefficients for FIR filter.
+        valid_times : np.ndarray
+            Array with start and stop times of intervals to be filtered.
+        electrode_ids : list
+            List of electrode IDs to filter.
+        decimation : int
+            Decimation factor.
+        description : str
+            Description of the filtered data.
+        type : Union[None, str]
+            Type of data (e.g., "LFP").
+
+        Returns
+        -------
+        tuple
+            The NWB object ID of the filtered data and a list containing the
+            first and last timestamps.
         """
+        MEM_USE_LIMIT = 0.9  # fraction of available RAM permitted for use
+
         gsp = _import_ghostipy()

         data_on_disk = eseries.data
         timestamps_on_disk = eseries.timestamps
-        n_dim = len(data_on_disk.shape)
+
         n_samples = len(timestamps_on_disk)
-        # find the
         time_axis = 0 if data_on_disk.shape[0] == n_samples else 1
         electrode_axis = 1 - time_axis
+
         n_electrodes = data_on_disk.shape[electrode_axis]
-        input_dim_restrictions = [None] * n_dim
+        input_dim_restrictions = [None] * len(data_on_disk.shape)

-        # to get the input dimension restrictions we need to look at the electrode table for the eseries and get
-        # the indices from that
+        # Get input dimension restrictions
         input_dim_restrictions[electrode_axis] = np.s_[
             get_electrode_indices(eseries, electrode_ids)
         ]

         indices = []
-        output_shape_list = [0] * n_dim
+        output_shape_list = [0] * len(data_on_disk.shape)
         output_shape_list[electrode_axis] = len(electrode_ids)
-        output_offsets = [0]
-
-        timestamp_size = timestamps_on_disk[0].itemsize
-        timestamp_dtype = timestamps_on_disk[0].dtype
-        data_size = data_on_disk[0][0].itemsize
         data_dtype = data_on_disk[0][0].dtype

         filter_delay = self.calc_filter_delay(filter_coeff)
+
+        output_offsets = [0]
+
         for a_start, a_stop in valid_times:
-            if a_start < timestamps_on_disk[0]:
-                warnings.warn(
-                    f"Interval start time {a_start} is smaller than first timestamp {timestamps_on_disk[0]}, "
-                    "using first timestamp instead"
-                )
-                a_start = timestamps_on_disk[0]
-            if a_stop > timestamps_on_disk[-1]:
-                warnings.warn(
-                    f"Interval stop time {a_stop} is larger than last timestamp {timestamps_on_disk[-1]}, "
-                    "using last timestamp instead"
-                )
-                a_stop = timestamps_on_disk[-1]
-            frm, to = np.searchsorted(timestamps_on_disk, (a_start, a_stop))
-            if to > n_samples:
-                to = n_samples
+            frm, to = self._time_bound_check(
+                a_start, a_stop, timestamps_on_disk, n_samples
+            )
+
             indices.append((frm, to))
-
-            shape, dtype = 
gsp.filter_data_fir( + + shape, _ = gsp.filter_data_fir( data_on_disk, filter_coeff, axis=time_axis, @@ -240,119 +304,107 @@ def filter_data_nwb( output_offsets.append(output_offsets[-1] + shape[time_axis]) output_shape_list[time_axis] += shape[time_axis] - # open the nwb file to create the dynamic table region and electrode series, then write and close the file + # Create dynamic table region and electrode series, write/close file with pynwb.NWBHDF5IO( path=analysis_file_abs_path, mode="a", load_namespaces=True ) as io: nwbf = io.read() + # get the indices of the electrodes in the electrode table elect_ind = get_electrode_indices(nwbf, electrode_ids) electrode_table_region = nwbf.create_electrode_table_region( elect_ind, "filtered electrode table" ) - eseries_name = "filtered data" es = pynwb.ecephys.ElectricalSeries( - name=eseries_name, + name="filtered data", data=np.empty(tuple(output_shape_list), dtype=data_dtype), electrodes=electrode_table_region, timestamps=np.empty(output_shape_list[time_axis]), description=description, ) if type == "LFP": - lfp = pynwb.ecephys.LFP(electrical_series=es) ecephys_module = nwbf.create_processing_module( name="ecephys", description=description ) - ecephys_module.add(lfp) + ecephys_module.add(pynwb.ecephys.LFP(electrical_series=es)) else: nwbf.add_scratch(es) + io.write(nwbf) - # reload the NWB file to get the h5py objects for the data and the timestamps - with pynwb.NWBHDF5IO( - path=analysis_file_abs_path, mode="a", load_namespaces=True - ) as io: - nwbf = io.read() - es = nwbf.objects[es.object_id] - filtered_data = es.data - new_timestamps = es.timestamps - indices = np.array(indices, ndmin=2) - # Filter and write the output dataset - ts_offset = 0 - - print("Filtering data") - for ii, (start, stop) in enumerate(indices): - # calculate the size of the timestamps and the data and determine whether they - # can be loaded into < 90% of available RAM - mem = psutil.virtual_memory() - interval_samples = stop - start - if ( - interval_samples - * (timestamp_size + n_electrodes * data_size) - < 0.9 * mem.available - ): - print(f"Interval {ii}: loading data into memory") - timestamps = np.asarray( - timestamps_on_disk[start:stop], - dtype=timestamp_dtype, - ) - if time_axis == 0: - data = np.asarray( - data_on_disk[start:stop, :], dtype=data_dtype - ) - else: - data = np.asarray( - data_on_disk[:, start:stop], dtype=data_dtype - ) - extracted_ts = timestamps[0::decimation] - new_timestamps[ - ts_offset : ts_offset + len(extracted_ts) - ] = extracted_ts - ts_offset += len(extracted_ts) - # filter the data - gsp.filter_data_fir( - data, - filter_coeff, - axis=time_axis, - input_index_bounds=[0, interval_samples - 1], - output_index_bounds=[ - filter_delay, - filter_delay + stop - start, - ], - ds=decimation, - input_dim_restrictions=input_dim_restrictions, - outarray=filtered_data, - output_offset=output_offsets[ii], + # Reload NWB file to get h5py objects for data/timestamps + # NOTE: CBroz - why io context within io context? 
Unindenting
+        with pynwb.NWBHDF5IO(
+            path=analysis_file_abs_path, mode="a", load_namespaces=True
+        ) as io:
+            nwbf = io.read()
+            es = nwbf.objects[es.object_id]
+            filtered_data = es.data
+            new_timestamps = es.timestamps
+            indices = np.array(indices, ndmin=2)
+            # Filter and write the output dataset
+            ts_offset = 0
+
+            print("Filtering data")
+            for ii, (start, stop) in enumerate(indices):
+                # Calc size of timestamps + data, check if < 90% of RAM
+                interval_samples = stop - start
+                req_mem = interval_samples * (
+                    timestamps_on_disk[0].itemsize
+                    + n_electrodes * data_on_disk[0][0].itemsize
+                )
+                if req_mem < MEM_USE_LIMIT * psutil.virtual_memory().available:
+                    print(f"Interval {ii}: loading data into memory")
+                    timestamps = np.asarray(
+                        timestamps_on_disk[start:stop],
+                        dtype=timestamps_on_disk[0].dtype,
+                    )
+                    if time_axis == 0:
+                        data = np.asarray(
+                            data_on_disk[start:stop, :], dtype=data_dtype
                         )
                     else:
-                    print(f"Interval {ii}: leaving data on disk")
-                    data = data_on_disk
-                    timestamps = timestamps_on_disk
-                    extracted_ts = timestamps[start:stop:decimation]
-                    new_timestamps[
-                        ts_offset : ts_offset + len(extracted_ts)
-                    ] = extracted_ts
-                    ts_offset += len(extracted_ts)
-                    # filter the data
-                    gsp.filter_data_fir(
-                        data,
-                        filter_coeff,
-                        axis=time_axis,
-                        input_index_bounds=[start, stop],
-                        output_index_bounds=[
-                            filter_delay,
-                            filter_delay + stop - start,
-                        ],
-                        ds=decimation,
-                        input_dim_restrictions=input_dim_restrictions,
-                        outarray=filtered_data,
-                        output_offset=output_offsets[ii],
+                        data = np.asarray(
+                            data_on_disk[:, start:stop], dtype=data_dtype
                         )
+                    extracted_ts = timestamps[0::decimation]
+                    new_timestamps[
+                        ts_offset : ts_offset + len(extracted_ts)
+                    ] = extracted_ts
+                    ts_offset += len(extracted_ts)
+                    input_index_bounds = [0, interval_samples - 1]
+
+                else:
+                    print(f"Interval {ii}: leaving data on disk")
+                    data = data_on_disk
+                    timestamps = timestamps_on_disk
+                    extracted_ts = timestamps[start:stop:decimation]
+                    new_timestamps[
+                        ts_offset : ts_offset + len(extracted_ts)
+                    ] = extracted_ts
+                    ts_offset += len(extracted_ts)
+                    input_index_bounds = [start, stop]
+
+                # filter the data
+                gsp.filter_data_fir(
+                    data,
+                    filter_coeff,
+                    axis=time_axis,
+                    input_index_bounds=input_index_bounds,
+                    output_index_bounds=[
+                        filter_delay,
+                        filter_delay + stop - start,
+                    ],
+                    ds=decimation,
+                    input_dim_restrictions=input_dim_restrictions,
+                    outarray=filtered_data,
+                    output_offset=output_offsets[ii],
+                )

-            start_end = [new_timestamps[0], new_timestamps[-1]]
+            start_end = [new_timestamps[0], new_timestamps[-1]]

-            io.write(nwbf)
+            io.write(nwbf)

         return es.object_id, start_end

@@ -366,14 +418,26 @@ def filter_data(
         decimation,
     ):
         """
-        :param timestamps: numpy array with list of timestamps for data
-        :param data: original data array
-        :param filter_coeff: numpy array with filter coefficients for FIR filter
-        :param valid_times: 2D numpy array with start and stop times of intervals to be filtered
-        :param electrodes: list of electrodes to filter
-        :param decimation: decimation factor
-        :return: filtered_data, timestamps
+        Parameters
+        ----------
+        timestamps: numpy array
+            List of timestamps for data
+        data: numpy array
+            original data array
+        filter_coeff: numpy array
+            Filter coefficients for FIR filter
+        valid_times: 2D numpy array
+            Start and stop times of intervals to be filtered
+        electrodes: list
+            Electrodes to filter
+        decimation: int
+            decimation factor
+
+        Returns
+        -------
+        filtered_data, timestamps
         """
+
         gsp = _import_ghostipy()

         n_dim = len(data.shape)
@@ -390,24 +454,13 @@ def filter_data(
         filter_delay = 
self.calc_filter_delay(filter_coeff)

         for a_start, a_stop in valid_times:
-            if a_start < timestamps[0]:
-                print(
-                    f"Interval start time {a_start} is smaller than first timestamp "
-                    f"{timestamps[0]}, using first timestamp instead"
-                )
-                a_start = timestamps[0]
-            if a_stop > timestamps[-1]:
-                print(
-                    f"Interval stop time {a_stop} is larger than last timestamp "
-                    f"{timestamps[-1]}, using last timestamp instead"
-                )
-                a_stop = timestamps[-1]
-            frm, to = np.searchsorted(timestamps, (a_start, a_stop))
-            if to > n_samples:
-                to = n_samples
+            frm, to = self._time_bound_check(
+                a_start, a_stop, timestamps, n_samples
+            )
             indices.append((frm, to))
-            shape, dtype = gsp.filter_data_fir(
+
+            shape, _ = gsp.filter_data_fir(
                 data,
                 filter_coeff,
                 axis=time_axis,
@@ -458,14 +511,20 @@ def filter_data(

     def calc_filter_delay(self, filter_coeff):
         """
-        :param filter_coeff:
-        :return: filter delay
+        Parameters
+        ----------
+        filter_coeff: numpy array
+            Filter coefficients for FIR filter
+
+        Returns
+        -------
+        filter delay: int
         """
         return (len(filter_coeff) - 1) // 2

     def create_standard_filters(self):
-        """Add standard filters to the Filter table including
-        0-400 Hz low pass for continuous raw data -> LFP
+        """Add standard filters to the Filter table
+
+        Includes 0-400 Hz low pass for continuous raw data -> LFP
         """
         self.add_filter(
             "LFP 0-400 Hz",
diff --git a/src/spyglass/common/common_interval.py b/src/spyglass/common/common_interval.py
index 9aa184e9e..804723685 100644
--- a/src/spyglass/common/common_interval.py
+++ b/src/spyglass/common/common_interval.py
@@ -20,42 +20,45 @@ class IntervalList(dj.Manual):
     -> Session
     interval_list_name: varchar(200)  # descriptive name of this interval list
     ---
-    valid_times: longblob # numpy array with start and end times for each interval
+    valid_times: longblob  # numpy array with start/end times for each interval
     """

     @classmethod
     def insert_from_nwbfile(cls, nwbf, *, nwb_file_name):
-        """Add each entry in the NWB file epochs table to the IntervalList table.
+        """Add each entry in the NWB file epochs table to the IntervalList.

-        The interval list name for each epoch is set to the first tag for the epoch.
-        If the epoch has no tags, then 'interval_x' will be used as the interval list name, where x is the index
-        (0-indexed) of the epoch in the epochs table.
-        The start time and stop time of the epoch are stored in the valid_times field as a numpy array of
-        [start time, stop time] for each epoch.
+        The interval list name for each epoch is set to the first tag for the
+        epoch. If the epoch has no tags, then 'interval_x' will be used as the
+        interval list name, where x is the index (0-indexed) of the epoch in the
+        epochs table. The start time and stop time of the epoch are stored in
+        the valid_times field as a numpy array of [start time, stop time] for
+        each epoch.

         Parameters
         ----------
         nwbf : pynwb.NWBFile
             The source NWB file object.
         nwb_file_name : str
-            The file name of the NWB file, used as a primary key to the Session table.
+            The file name of the NWB file, used as a primary key to the Session
+            table. 
""" if nwbf.epochs is None: print("No epochs found in NWB file.") return + epochs = nwbf.epochs.to_dataframe() - for epoch_index, epoch_data in epochs.iterrows(): - epoch_dict = dict() - epoch_dict["nwb_file_name"] = nwb_file_name - if epoch_data.tags[0]: - epoch_dict["interval_list_name"] = epoch_data.tags[0] - else: - epoch_dict["interval_list_name"] = "interval_" + str( - epoch_index - ) - epoch_dict["valid_times"] = np.asarray( - [[epoch_data.start_time, epoch_data.stop_time]] - ) + + for _, epoch_data in epochs.iterrows(): + epoch_dict = { + "nwb_file_name": nwb_file_name, + "interval_list_name": epoch_data[1].tags[0] + if epoch_data[1].tags + else f"interval_{epoch_data[0]}", + "valid_times": np.asarray( + [[epoch_data[1].start_time, epoch_data[1].stop_time]] + ), + } + cls.insert1(epoch_dict, skip_duplicates=True) def plot_intervals(self, figsize=(20, 5)): @@ -145,7 +148,7 @@ def intervals_by_length(interval_list, min_length=0.0, max_length=1e10): Parameters ---------- interval_list : array_like - Each element is (start time, stop time), i.e. an interval. Unit is seconds. + Each element is (start time, stop time), i.e. an interval in seconds. min_length : float, optional Minimum interval length in seconds. Defaults to 0.0. max_length : float, optional @@ -158,12 +161,12 @@ def intervals_by_length(interval_list, min_length=0.0, max_length=1e10): def interval_list_contains_ind(interval_list, timestamps): - """Find indices of a list of timestamps that are contained in an interval list. + """Find indices of list of timestamps contained in an interval list. Parameters ---------- interval_list : array_like - Each element is (start time, stop time), i.e. an interval. Unit is seconds. + Each element is (start time, stop time), i.e. an interval in seconds. timestamps : array_like """ ind = [] @@ -184,7 +187,7 @@ def interval_list_contains(interval_list, timestamps): Parameters ---------- interval_list : array_like - Each element is (start time, stop time), i.e. an interval. Unit is seconds. + Each element is (start time, stop time), i.e. an interval in seconds. timestamps : array_like """ ind = [] @@ -200,28 +203,17 @@ def interval_list_contains(interval_list, timestamps): def interval_list_excludes_ind(interval_list, timestamps): - """Find indices of a list of timestamps that are not contained in an interval list. + """Find indices of timestamps that are not contained in an interval list. Parameters ---------- interval_list : array_like - Each element is (start time, stop time), i.e. an interval. Unit is seconds. + Each element is (start time, stop time), i.e. an interval in seconds. 
timestamps : array_like """ contained_inds = interval_list_contains_ind(interval_list, timestamps) return np.setdiff1d(np.arange(len(timestamps)), contained_inds) - # # add the first and last times to the list and creat a list of invalid intervals - # valid_times_list = np.ndarray.ravel(interval_list).tolist() - # valid_times_list.insert(0, timestamps[0] - 0.00001) - # valid_times_list.append(timestamps[-1] + 0.001) - # invalid_times = np.array(valid_times_list).reshape(-1, 2) - # # add the first and last timestamp indices - # ind = [] - # for invalid_time in invalid_times: - # ind += np.ravel(np.argwhere(np.logical_and(timestamps > invalid_time[0], - # timestamps < invalid_time[1]))).tolist() - # return np.asarray(ind) def interval_list_excludes(interval_list, timestamps): @@ -230,22 +222,24 @@ def interval_list_excludes(interval_list, timestamps): Parameters ---------- interval_list : array_like - Each element is (start time, stop time), i.e. an interval. Unit is seconds. + Each element is (start time, stop time), i.e. an interval in seconds. timestamps : array_like """ contained_times = interval_list_contains(interval_list, timestamps) return np.setdiff1d(timestamps, contained_times) - # # add the first and last times to the list and creat a list of invalid intervals - # valid_times_list = np.ravel(valid_times).tolist() - # valid_times_list.insert(0, timestamps[0] - 0.00001) - # valid_times_list.append(timestamps[-1] + 0.00001) - # invalid_times = np.array(valid_times_list).reshape(-1, 2) - # # add the first and last timestamp indices - # ind = [] - # for invalid_time in invalid_times: - # ind += np.ravel(np.argwhere(np.logical_and(timestamps > invalid_time[0], - # timestamps < invalid_time[1]))).tolist() - # return timestamps[ind] + + +def consolidate_intervals(interval_list): + if interval_list.ndim == 1: + interval_list = np.expand_dims(interval_list, 0) + else: + interval_list = interval_list[np.argsort(interval_list[:, 0])] + interval_list = reduce(_union_concat, interval_list) + # the following check is needed in the case where the interval list is a + # single element (behavior of reduce) + if interval_list.ndim == 1: + interval_list = np.expand_dims(interval_list, 0) + return interval_list def interval_list_intersect(interval_list1, interval_list2, min_length=0): @@ -265,76 +259,51 @@ def interval_list_intersect(interval_list1, interval_list2, min_length=0): interval_list: np.array, (N,2) """ - # first, consolidate interval lists to disjoint intervals by sorting and applying union - if interval_list1.ndim == 1: - interval_list1 = np.expand_dims(interval_list1, 0) - else: - interval_list1 = interval_list1[np.argsort(interval_list1[:, 0])] - interval_list1 = reduce(_union_concat, interval_list1) - # the following check is needed in the case where the interval list is a single element (behavior of reduce) - if interval_list1.ndim == 1: - interval_list1 = np.expand_dims(interval_list1, 0) - - if interval_list2.ndim == 1: - interval_list2 = np.expand_dims(interval_list2, 0) - else: - interval_list2 = interval_list2[np.argsort(interval_list2[:, 0])] - interval_list2 = reduce(_union_concat, interval_list2) - # the following check is needed in the case where the interval list is a single element (behavior of reduce) - if interval_list2.ndim == 1: - interval_list2 = np.expand_dims(interval_list2, 0) + # Consolidate interval lists to disjoint int'ls by sorting & applying union + interval_list1 = consolidate_intervals(interval_list1) + interval_list2 = consolidate_intervals(interval_list2) 
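+    # e.g. (hypothetical intervals, in seconds): consolidation turns
+    # [[0, 2], [1, 3], [5, 6]] into [[0, 3], [5, 6]], so overlapping inputs
+    # are only counted once in the pairwise intersection below.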
# then do pairwise comparison and collect intersections - intersecting_intervals = [] - for interval2 in interval_list2: - for interval1 in interval_list1: - if _intersection(interval2, interval1) is not None: - intersecting_intervals.append( - _intersection(interval1, interval2) - ) + intersecting_intervals = [ + _intersection(interval2, interval1) + for interval2 in interval_list2 + for interval1 in interval_list1 + if _intersection(interval2, interval1) is not None + ] # if no intersection, then return an empty list if not intersecting_intervals: return [] - else: - intersecting_intervals = np.asarray(intersecting_intervals) - intersecting_intervals = intersecting_intervals[ - np.argsort(intersecting_intervals[:, 0]) - ] - return intervals_by_length( - intersecting_intervals, min_length=min_length - ) + intersecting_intervals = np.asarray(intersecting_intervals) + intersecting_intervals = intersecting_intervals[ + np.argsort(intersecting_intervals[:, 0]) + ] + + return intervals_by_length(intersecting_intervals, min_length=min_length) def _intersection(interval1, interval2): - "Takes the (set-theoretic) intersection of two intervals" - intersection = np.array( - [max([interval1[0], interval2[0]]), min([interval1[1], interval2[1]])] - ) - if intersection[1] > intersection[0]: - return intersection - else: - return None + """Takes the (set-theoretic) intersection of two intervals""" + start = max(interval1[0], interval2[0]) + end = min(interval1[1], interval2[1]) + intersection = np.array([start, end]) if end > start else None + return intersection def _union(interval1, interval2): - "Takes the (set-theoretic) union of two intervals" + """Takes the (set-theoretic) union of two intervals""" if _intersection(interval1, interval2) is None: return np.array([interval1, interval2]) - else: - return np.array( - [ - min([interval1[0], interval2[0]]), - max([interval1[1], interval2[1]]), - ] - ) + return np.array( + [min(interval1[0], interval2[0]), max(interval1[1], interval2[1])] + ) def _union_concat(interval_list, interval): - """Compares the last interval of the interval list to the given interval and - * takes their union if overlapping - * concatenates the interval to the interval list if not + """Compare last interval of interval list to given interval. + + If overlapping, take union. If not, concatenate interval to interval list. Recursively called with `reduce`. """ @@ -344,27 +313,23 @@ def _union_concat(interval_list, interval): interval = np.expand_dims(interval, 0) x = _union(interval_list[-1], interval[0]) - if x.ndim == 1: - x = np.expand_dims(x, 0) + x = np.expand_dims(x, 0) if x.ndim == 1 else x + return np.concatenate((interval_list[:-1], x), axis=0) def union_adjacent_index(interval1, interval2): - """unions two intervals that are adjacent in index + """Union index-adjacent intervals. If not adjacent, just concatenate. + e.g. 
[a,b] and [b+1, c] is converted to [a,c] - if not adjacent, just concatenates interval2 at the end of interval1 Parameters ---------- interval1 : np.array - [description] interval2 : np.array - [description] """ - if interval1.ndim == 1: - interval1 = np.expand_dims(interval1, 0) - if interval2.ndim == 1: - interval2 = np.expand_dims(interval2, 0) + interval1 = np.atleast_2d(interval1) + interval2 = np.atleast_2d(interval2) if ( interval1[-1][1] + 1 == interval2[0][0] @@ -386,50 +351,63 @@ def union_adjacent_index(interval1, interval2): # TODO: test interval_list_union code +def _parallel_union(interval_list): + """Create a parallel list where 1 is start and -1 the end""" + interval_list = np.ravel(interval_list) + interval_list_start_end = np.ones(interval_list.shape) + interval_list_start_end[1::2] = -1 + return interval_list, interval_list_start_end + + def interval_list_union( - interval_list1, interval_list2, min_length=0.0, max_length=1e10 -): + interval_list1: np.ndarray, + interval_list2: np.ndarray, + min_length: float = 0.0, + max_length: float = 1e10, +) -> np.ndarray: """Finds the union (all times in one or both) for two interval lists - :param interval_list1: The first interval list - :type interval_list1: numpy array of intervals [start, stop] - :param interval_list2: The second interval list - :type interval_list2: numpy array of intervals [start, stop] - :param min_length: optional minimum length of interval for inclusion in output, default 0.0 - :type min_length: float - :param max_length: optional maximum length of interval for inclusion in output, default 1e10 - :type max_length: float - :return: interval_list - :rtype: numpy array of intervals [start, stop] + Parameters + ---------- + interval_list1 : np.ndarray + The first interval list [start, stop] + interval_list2 : np.ndarray + The second interval list [start, stop] + min_length : float, optional + Minimum length of interval for inclusion in output, default 0.0 + max_length : float, optional + Maximum length of interval for inclusion in output, default 1e10 + + Returns + ------- + np.ndarray + Array of intervals [start, stop] """ - # return np.array([min(interval_list1[0],interval_list2[0]), - # max(interval_list1[1],interval_list2[1])]) - interval_list1 = np.ravel(interval_list1) - # create a parallel list where 1 indicates the start and -1 the end of an interval - interval_list1_start_end = np.ones(interval_list1.shape) - interval_list1_start_end[1::2] = -1 - - interval_list2 = np.ravel(interval_list2) - # create a parallel list for the second interval where 1 indicates the start and -1 the end of an interval - interval_list2_start_end = np.ones(interval_list2.shape) - interval_list2_start_end[1::2] = -1 - - # concatenate the two lists so we can resort the intervals and apply the same sorting to the start-end arrays - combined_intervals = np.concatenate((interval_list1, interval_list2)) - ss = np.concatenate((interval_list1_start_end, interval_list2_start_end)) + + il1, il1_start_end = _parallel_union(interval_list1) + il2, il2_start_end = _parallel_union(interval_list2) + + # Concatenate the two lists so we can resort the intervals and apply the + # same sorting to the start-end arrays + combined_intervals = np.concatenate((il1, il2)) + ss = np.concatenate((il1_start_end, il2_start_end)) sort_ind = np.argsort(combined_intervals) combined_intervals = combined_intervals[sort_ind] - # a cumulative sum of 1 indicates the beginning of a joint interval; a cumulative sum of 0 indicates the end + + # a cumulative sum of 
1 indicates the beginning of a joint interval; a
+    # cumulative sum of 0 indicates the end
     union_starts = np.ravel(np.array(np.where(np.cumsum(ss[sort_ind]) == 1)))
     union_stops = np.ravel(np.array(np.where(np.cumsum(ss[sort_ind]) == 0)))
-    union = []
-    for start, stop in zip(union_starts, union_stops):
-        union.append([combined_intervals[start], combined_intervals[stop]])
+    union = [
+        [combined_intervals[start], combined_intervals[stop]]
+        for start, stop in zip(union_starts, union_stops)
+    ]
+
     return np.asarray(union)


 def interval_list_censor(interval_list, timestamps):
-    """returns a new interval list that starts and ends at the first and last timestamp
+    """Returns new interval list that starts/ends at first/last timestamp

     Parameters
     ----------
@@ -442,9 +420,10 @@ def interval_list_censor(interval_list, timestamps):
     interval_list (numpy array of intervals [start, stop])
     """
     # check that all timestamps are in the interval list
-    assert len(interval_list_contains_ind(interval_list, timestamps)) == len(
+    if len(interval_list_contains_ind(interval_list, timestamps)) != len(
         timestamps
-    ), "interval_list must contain all timestamps"
+    ):
+        raise ValueError("interval_list must contain all timestamps")

     timestamps_interval = np.asarray([[timestamps[0], timestamps[-1]]])
     return interval_list_intersect(interval_list, timestamps_interval)
@@ -452,6 +431,7 @@ def interval_list_censor(interval_list, timestamps):

 def interval_from_inds(list_frames):
     """Converts a list of indices to a list of intervals.
+
     e.g. [2,3,4,6,7,8,9,10] -> [[2,4],[6,10]]

     Parameters
diff --git a/src/spyglass/common/common_lab.py b/src/spyglass/common/common_lab.py
index 56333b679..3088e1906 100644
--- a/src/spyglass/common/common_lab.py
+++ b/src/spyglass/common/common_lab.py
@@ -25,8 +25,8 @@ class LabMemberInfo(dj.Part):
         # Information about lab member in the context of Frank lab network
         -> LabMember
         ---
-        google_user_name: varchar(200) # used for permission to curate
-        datajoint_user_name = "": varchar(200) # used for permission to delete entries
+        google_user_name: varchar(200)  # For permission to curate
+        datajoint_user_name = "": varchar(200)  # For permission to delete entries
         """

     @classmethod
@@ -108,9 +108,10 @@ def create_new_team(
         team_description: str
             The description of the team.
         """
-        labteam_dict = dict()
-        labteam_dict["team_name"] = team_name
-        labteam_dict["team_description"] = team_description
+        labteam_dict = {
+            "team_name": team_name,
+            "team_description": team_description,
+        }
         cls.insert1(labteam_dict, skip_duplicates=True)

         for team_member in team_members:
@@ -120,12 +121,13 @@ def create_new_team(
             ).fetch("google_user_name")
             if not query:
                 print(
-                    f"Please add the Google user ID for {team_member} in the "
-                    + "LabMember.LabMemberInfo table to help manage permissions."
+                    f"Please add the Google user ID for {team_member} in "
+                    + "LabMember.LabMemberInfo to help manage permissions." 
)

-            labteammember_dict = dict()
-            labteammember_dict["team_name"] = team_name
-            labteammember_dict["lab_member_name"] = team_member
+            labteammember_dict = {
+                "team_name": team_name,
+                "lab_member_name": team_member,
+            }
             cls.LabTeamMember.insert1(labteammember_dict, skip_duplicates=True)


@@ -133,7 +135,6 @@ def create_new_team(
 class Institution(dj.Manual):
     definition = """
     institution_name: varchar(80)
-    ---
     """

     @classmethod
@@ -148,6 +149,7 @@ def insert_from_nwbfile(cls, nwbf):
         if nwbf.institution is None:
             print("No institution metadata found.\n")
             return
+
         cls.insert1(
             dict(institution_name=nwbf.institution), skip_duplicates=True
         )
@@ -157,7 +159,6 @@ def insert_from_nwbfile(cls, nwbf):
 class Lab(dj.Manual):
     definition = """
     lab_name: varchar(80)
-    ---
     """

     @classmethod
@@ -198,5 +199,7 @@ def decompose_name(full_name: str) -> tuple:
         last, first = full_name.title().split(", ")
     else:
         first, last = full_name.title().split(" ")
+
     full = f"{first} {last}"
+
     return full, first, last
diff --git a/src/spyglass/common/common_nwbfile.py b/src/spyglass/common/common_nwbfile.py
index 493e98962..4ad81e178 100644
--- a/src/spyglass/common/common_nwbfile.py
+++ b/src/spyglass/common/common_nwbfile.py
@@ -1,5 +1,5 @@
 import os
-import pathlib
+from pathlib import Path
 import random
 import stat
 import string
@@ -12,7 +12,9 @@
 import spikeinterface as si
 from hdmf.common import DynamicTable

-from ..settings import load_config
+import kachery_cloud as kcl
+
+from ..settings import raw_dir, analysis_dir
 from ..utils.dj_helper_fn import get_child_tables
 from ..utils.nwb_helper_fn import get_electrode_indices, get_nwb_file

@@ -47,8 +48,8 @@ class Nwbfile(dj.Manual):
     nwb_file_abs_path: filepath@raw
     INDEX (nwb_file_abs_path)
     """
-    # NOTE the INDEX above is implicit from filepath@... above but needs to be explicit
-    # so that alter() can work
+    # NOTE the INDEX above is implicit from filepath@... above but needs to be
+    # explicit so that alter() can work

     @classmethod
     def insert_from_relative_file_name(cls, nwb_file_name):
@@ -78,7 +79,8 @@ def get_abs_path(nwb_file_name):
         Parameters
         ----------
         nwb_file_name : str
-            The name of an NWB file that has been inserted into the Nwbfile() schema.
+            The name of an NWB file that has been inserted into the Nwbfile()
+            schema.

         Returns
         -------
@@ -86,18 +88,19 @@ def get_abs_path(nwb_file_name):
             The absolute path for the given file name.
         """

-        return load_config()["SPYGLASS_RAW_DIR"] + "/" + nwb_file_name
+        return str(Path(raw_dir) / nwb_file_name)

     @staticmethod
     def add_to_lock(nwb_file_name):
-        """Add the specified NWB file to the file with the list of NWB files to be locked.
+        """Add given file to the list of NWB files to be locked.

         The NWB_LOCK_FILE environment variable must be set.

         Parameters
         ----------
         nwb_file_name : str
-            The name of an NWB file that has been inserted into the Nwbfile() schema.
+            The name of an NWB file that has been inserted into the Nwbfile()
+            schema.
         """
         key = {"nwb_file_name": nwb_file_name}
         # check to make sure the file exists
@@ -113,34 +116,36 @@ def add_to_lock(nwb_file_name):
     def cleanup(delete_files=False):
         """Remove the filepath entries for NWB files that are not in use.

-        This does not delete the files themselves unless delete_files=True is specified
-        Run this after deleting the Nwbfile() entries themselves.
+        This does not delete the files themselves unless delete_files=True is
+        specified. Run this after deleting the Nwbfile() entries themselves. 
""" schema.external["raw"].delete(delete_external_files=delete_files) -# TODO: add_to_kachery will not work because we can't update the entry after it's been used in another table. -# We therefore need another way to keep track of the +# TODO: add_to_kachery will not work because we can't update the entry after +# it's been used in another table. We therefore need another way to keep track +# of the @schema class AnalysisNwbfile(dj.Manual): definition = """ - # Table for holding the NWB files that contain results of analysis, such as spike sorting. + # NWB files that contain results of analysis, such as spike sorting. analysis_file_name: varchar(255) # name of the file --- - -> Nwbfile # name of the parent NWB file. Used for naming and metadata copy + -> Nwbfile # Parent NWB filename. For + # naming and metadata copy analysis_file_abs_path: filepath@analysis # the full path to the file - analysis_file_description = "": varchar(2000) # an optional description of this analysis - analysis_parameters = NULL: blob # additional relevant parameters. Currently used only for analyses - # that span multiple NWB files + analysis_file_description = "": varchar(2000) # Optional description + analysis_parameters = NULL: blob # additional relevant params. INDEX (analysis_file_abs_path) """ - # NOTE the INDEX above is implicit from filepath@... above but needs to be explicit - # so that alter() can work + # NOTE the INDEX above is implicit from filepath@... above but needs to be + # explicit so that alter() can work def create(self, nwb_file_name): - """Open the NWB file, create a copy, write the copy to disk and return the name of the new file. + """Open NWB file, create copy, write copy to disk, return new file name. - Note that this does NOT add the file to the schema; that needs to be done after data are written to it. + Note that this does NOT add the file to the schema; that needs to be + done after data are written to it. Parameters ---------- @@ -186,33 +191,28 @@ def create(self, nwb_file_name): @classmethod def __get_new_file_name(cls, nwb_file_name): - # each file ends with a random string of 10 digits, so we generate that string and redo if by some miracle - # it's already there - file_in_table = True - while file_in_table: + while True: + random_string = "".join( + random.choices(string.ascii_uppercase + string.digits, k=10) + ) analysis_file_name = ( - os.path.splitext(nwb_file_name)[0] - + "".join( - random.choices(string.ascii_uppercase + string.digits, k=10) - ) - + ".nwb" + f"{os.path.splitext(nwb_file_name)[0]}{random_string}.nwb" ) - file_in_table = AnalysisNwbfile & { - "analysis_file_name": analysis_file_name - } - - return analysis_file_name + if not AnalysisNwbfile & {"analysis_file_name": analysis_file_name}: + return analysis_file_name @classmethod def __get_analysis_file_dir(cls, analysis_file_name: str): - # strip off everything after and including the final underscore and return the result + # strip off everything after and including the final underscore and + # return the result return analysis_file_name[0 : analysis_file_name.rfind("_")] @classmethod def copy(cls, nwb_file_name): """Make a copy of an analysis NWB file. - Note that this does NOT add the file to the schema; that needs to be done after data are written to it. + Note that this does NOT add the file to the schema; that needs to be + done after data are written to it. 
Parameters
         ----------

@@ -267,35 +267,30 @@ def add(self, nwb_file_name, analysis_file_name):

     @staticmethod
     def get_abs_path(analysis_nwb_file_name):
-        """Return the absolute path for a stored analysis NWB file given just the file name.
-
-        The SPYGLASS_BASE_DIR environment variable must be set.
+        """Return absolute path for a stored analysis NWB file given file name.

         Parameters
         ----------
         analysis_nwb_file_name : str
-            The name of the NWB file that has been inserted into the AnalysisNwbfile() schema
+            The name of the NWB file that has been inserted into the
+            AnalysisNwbfile() schema

         Returns
         -------
         analysis_nwb_file_abspath : str
             The absolute path for the given file name.
         """
-        base_dir = pathlib.Path(os.getenv("SPYGLASS_BASE_DIR", None))
-        assert (
-            base_dir is not None
-        ), "You must set SPYGLASS_BASE_DIR environment variable."
+        analysis_path = Path(analysis_dir)

         # see if the file exists and is stored in the base analysis dir
-        test_path = str(base_dir / "analysis" / analysis_nwb_file_name)
+        test_path = str(analysis_path / analysis_nwb_file_name)

         if os.path.exists(test_path):
             return test_path
         else:
             # use the new path
             analysis_file_base_path = (
-                base_dir
-                / "analysis"
+                analysis_path
                 / AnalysisNwbfile.__get_analysis_file_dir(
                     analysis_nwb_file_name
                 )
@@ -307,9 +302,9 @@ def get_abs_path(analysis_nwb_file_name):
     def add_nwb_object(
         self, analysis_file_name, nwb_object, table_name="pandas_table"
     ):
-        # TODO: change to add_object with checks for object type and a name parameter, which should be specified if
-        # it is not an NWB container
-        """Add an NWB object to the analysis file in the scratch area and returns the NWB object ID
+        # TODO: change to add_object with checks for object type and a name
+        # parameter, which should be specified if it is not an NWB container
+        """Add NWB object to analysis file in 'scratch', return NWB object ID

         Parameters
         ----------
@@ -352,7 +347,7 @@ def add_units(
         metrics=None,
         units_waveforms=None,
         labels=None,
-    ):
+    ) -> tuple[str, str]:
         """Add units to analysis NWB file

         Parameters
@@ -375,7 +370,8 @@ def add_units(
         Returns
         -------
         units_object_id, waveforms_object_id : str, str
-            The NWB object id of the Units object and the object id of the waveforms object ('' if None)
+            The NWB object id of the Units object and the object id of the
+            waveforms object (empty strings if None)
         """
         with pynwb.NWBHDF5IO(
             path=self.get_abs_path(analysis_file_name),
@@ -384,70 +380,72 @@ def add_units(
         ) as io:
             nwbf = io.read()
             sort_intervals = list()
-            if len(units.keys()):
-                # Add spike times and valid time range for the sort
-                for id in units.keys():
-                    nwbf.add_unit(
-                        spike_times=units[id],
-                        id=id,
-                        # waveform_mean = units_templates[id],
-                        obs_intervals=units_valid_times[id],
-                    )
-                    sort_intervals.append(units_sort_interval[id])
-                # Add a column for the sort interval (subset of valid time)
+
+            if not units:
+                return "", ""
+
+            # Add spike times and valid time range for the sort
+            for id in units:
+                nwbf.add_unit(
+                    spike_times=units[id],
+                    id=id,
+                    # waveform_mean = units_templates[id],
+                    obs_intervals=units_valid_times[id],
+                )
+                sort_intervals.append(units_sort_interval[id])
+
+            # Add a column for the sort interval (subset of valid time)
+            nwbf.add_unit_column(
+                name="sort_interval",
+                description="the interval used for spike sorting",
+                data=sort_intervals,
+            )
+
+            # If metrics were specified, add one column per metric
+            if metrics:
+                for metric, values in metrics.items():
+                    if values:
+                        unit_ids = np.array(list(values.keys()))
+                        metric_values = 
np.array(list(values.values())) + # sort by unit_ids and apply that sorting to values to + # ensure that things go in the right order + metric_values = metric_values[np.argsort(unit_ids)] + print(f"Adding metric {metric} : {metric_values}") + nwbf.add_unit_column( + name=metric, + description=f"{metric} metric", + data=metric_values, + ) + + if labels: + unit_ids = np.array(list(units.keys())) + for unit in unit_ids: + if unit not in labels: + labels[unit] = "" + label_values = np.array(list(labels.values())) + label_values = label_values[np.argsort(unit_ids)].tolist() nwbf.add_unit_column( - name="sort_interval", - description="the interval used for spike sorting", - data=sort_intervals, + name="label", + description="label given during curation", + data=label_values, ) - # If metrics were specified, add one column per metric - if metrics is not None: - for metric in metrics: - if metrics[metric]: - unit_ids = np.array(list(metrics[metric].keys())) - metric_values = np.array( - list(metrics[metric].values()) - ) - # sort by unit_ids and apply that sorting to values to ensure that things go in the right order - metric_values = metric_values[np.argsort(unit_ids)] - print(f"Adding metric {metric} : {metric_values}") - nwbf.add_unit_column( - name=metric, - description=f"{metric} metric", - data=metric_values, - ) - if labels is not None: - unit_ids = np.array(list(units.keys())) - for unit in unit_ids: - if unit not in labels: - labels[unit] = "" - label_values = np.array(list(labels.values())) - label_values = label_values[np.argsort(unit_ids)].tolist() - nwbf.add_unit_column( - name="label", - description="label given during curation", - data=label_values, - ) - # If the waveforms were specified, add them as a dataframe to scratch - waveforms_object_id = "" - if units_waveforms is not None: - waveforms_df = pd.DataFrame.from_dict( - units_waveforms, orient="index" - ) - waveforms_df.columns = ["waveforms"] - nwbf.add_scratch( - waveforms_df, - name="units_waveforms", - notes="spike waveforms for each unit", - ) - waveforms_object_id = nwbf.scratch[ - "units_waveforms" - ].object_id - io.write(nwbf) - return nwbf.units.object_id, waveforms_object_id - else: - return "" + # If the waveforms were specified, add as a dataframe to scratch + waveforms_object_id = "" + if units_waveforms: + waveforms_df = pd.DataFrame.from_dict( + units_waveforms, orient="index" + ) + waveforms_df.columns = ["waveforms"] + nwbf.add_scratch( + waveforms_df, + name="units_waveforms", + notes="spike waveforms for each unit", + ) + waveforms_object_id = nwbf.scratch["units_waveforms"].object_id + + io.write(nwbf) + return nwbf.units.object_id, waveforms_object_id def add_units_waveforms( self, @@ -495,9 +493,10 @@ def add_units_waveforms( ) # The following is a rough sketch of AnalysisNwbfile().add_waveforms - # analysis_file_name = AnalysisNwbfile().create(key['nwb_file_name']) - # or - # nwbfile = pynwb.NWBFile(...) + # analysis_file_name = + # AnalysisNwbfile().create(key['nwb_file_name']) + # or nwbfile = pynwb.NWBFile(...) + # # (channels, spikes, samples) # wfs = [ # [ # elec 1 @@ -517,11 +516,14 @@ def add_units_waveforms( # [1, 2, 3] # spike 4 # ] # ] - # elecs = ... # DynamicTableRegion referring to three electrodes (rows) of the electrodes table - # nwbfile.add_unit(spike_times=[1, 2, 3], electrodes=elecs, waveforms=wfs) + # elecs = ... 
# DynamicTableRegion referring to three electrodes + # # (rows) of the electrodes table + # nwbfile.add_unit( + # spike_times=[1, 2, 3], electrodes=elecs, waveforms=wfs + # ) # If metrics were specified, add one column per metric - if metrics is not None: + if metrics: for metric_name, metric_dict in metrics.items(): print(f"Adding metric {metric_name} : {metric_dict}") metric_data = metric_dict.values().to_list() @@ -530,7 +532,7 @@ def add_units_waveforms( description=metric_name, data=metric_data, ) - if labels is not None: + if labels: nwbf.add_unit_column( name="label", description="label given during curation", @@ -578,7 +580,7 @@ def add_units_metrics(self, analysis_file_name, metrics): @classmethod def get_electrode_indices(cls, analysis_file_name, electrode_ids): - """Given an analysis NWB file name, returns the indices of the specified electrode_ids. + """Given analysis NWB file name & electrode IDs, return indices Parameters ---------- @@ -590,7 +592,7 @@ def get_electrode_indices(cls, analysis_file_name, electrode_ids): Returns ------- electrode_indices : numpy array - Array of indices in the electrodes table for the given electrode IDs. + Array of indices in the electrodes table for the given IDs. """ nwbf = get_nwb_file(cls.get_abs_path(analysis_file_name)) return get_electrode_indices(nwbf.electrodes, electrode_ids) @@ -599,8 +601,8 @@ def get_electrode_indices(cls, analysis_file_name, electrode_ids): def cleanup(delete_files=False): """Remove the filepath entries for NWB files that are not in use. - Does not delete the files themselves unless delete_files=True is specified. - Run this after deleting the Nwbfile() entries themselves. + Does not delete the files themselves unless delete_files=True is + specified. Run this after deleting the Nwbfile() entries themselves. Parameters ---------- @@ -618,7 +620,8 @@ def nightly_cleanup(): # during times when no other transactions are in progress. 
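
(Aside on the unit-metric handling in add_units above: both the metric and
label branches realign per-unit values to ascending unit id with argsort
before writing a column. The idiom in isolation, with example data:

    import numpy as np

    metric = {10: 0.95, 2: 0.80, 7: 0.99}      # unit_id -> value (example)
    unit_ids = np.array(list(metric.keys()))    # [10  2  7]
    values = np.array(list(metric.values()))    # [0.95 0.80 0.99]
    values = values[np.argsort(unit_ids)]       # [0.80 0.99 0.95] = ids 2, 7, 10

Note that the label branch takes unit_ids from units but values from labels,
so it implicitly assumes the two dicts iterate in the same order. Also, the
unchanged context line metric_data = metric_dict.values().to_list() in
add_units_waveforms would raise AttributeError on a plain dict;
list(metric_dict.values()) is the usual spelling.)
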
     AnalysisNwbfile.cleanup(True)
 
-    # also check to see whether there are directories in the spikesorting folder with this
+    # also check to see whether there are directories in the spikesorting
+    # folder with this
 
 
 @schema
 class NwbfileKachery(dj.Computed):
@@ -631,7 +634,7 @@ class NwbfileKachery(dj.Computed):
 
     def make(self, key):
         print(f'Linking {key["nwb_file_name"]} and storing in kachery...')
-        key["nwb_file_uri"] = kc.link_file(
+        key["nwb_file_uri"] = kcl.link_file(
             Nwbfile().get_abs_path(key["nwb_file_name"])
         )
         self.insert1(key)
@@ -647,7 +650,7 @@ class AnalysisNwbfileKachery(dj.Computed):
 
     def make(self, key):
         print(f'Linking {key["analysis_file_name"]} and storing in kachery...')
-        key["analysis_file_uri"] = kc.link_file(
+        key["analysis_file_uri"] = kcl.link_file(
             AnalysisNwbfile().get_abs_path(key["analysis_file_name"])
         )
         self.insert1(key)
diff --git a/src/spyglass/common/common_position.py b/src/spyglass/common/common_position.py
index d7db69985..0634ba4bb 100644
--- a/src/spyglass/common/common_position.py
+++ b/src/spyglass/common/common_position.py
@@ -26,7 +26,7 @@
 from ..settings import raw_dir
 from ..utils.dj_helper_fn import fetch_nwb
 from .common_behav import RawPosition, VideoFile
-from .common_interval import IntervalList
+from .common_interval import IntervalList  # noqa: F401
 from .common_nwbfile import AnalysisNwbfile
 
 schema = dj.schema("common_position")
@@ -42,13 +42,14 @@ class PositionInfoParameters(dj.Lookup):
     ---
     max_separation = 9.0 : float # max distance (in cm) between head LEDs
     max_speed = 300.0 : float # max speed (in cm / s) of animal
-    position_smoothing_duration = 0.125 : float # size of moving window (in seconds)
-    speed_smoothing_std_dev = 0.100 : float # smoothing standard deviation (in seconds)
-    head_orient_smoothing_std_dev = 0.001 : float # smoothing std deviation (in seconds)
-    led1_is_front = 1 : int # first LED is front LED and second is back LED, else first LED is back
+    position_smoothing_duration = 0.125 : float # size of moving window in s
+    speed_smoothing_std_dev = 0.100 : float # smoothing standard deviation in s
+    head_orient_smoothing_std_dev = 0.001 : float # smoothing std deviation in s
+    led1_is_front = 1 : int # 1 = first LED is front LED and second is back LED
     is_upsampled = 0 : int # upsample the position to higher sampling rate
     upsampling_sampling_rate = NULL : float # The rate to be upsampled to
+    upsampling_interpolation_method = linear : varchar(80)
+    # see pandas.DataFrame.interpolate for list of methods
-    upsampling_interpolation_method = linear : varchar(80) # see pandas.DataFrame.interpolation for list of methods
     """
 
 
@@ -106,16 +107,8 @@ def make(self, key):
         # calculate the processed position
         spatial_series = raw_position["raw_position"]
         position_info = self.calculate_position_info_from_spatial_series(
-            spatial_series,
-            position_info_parameters["max_separation"],
-            position_info_parameters["max_speed"],
-            position_info_parameters["speed_smoothing_std_dev"],
-            position_info_parameters["position_smoothing_duration"],
-            position_info_parameters["head_orient_smoothing_std_dev"],
-            position_info_parameters["led1_is_front"],
-            position_info_parameters["is_upsampled"],
-            position_info_parameters["upsampling_sampling_rate"],
-            position_info_parameters["upsampling_interpolation_method"],
+            spatial_series=spatial_series,
+            **position_info_parameters,
         )
 
         # create nwb objects for insertion into analysis nwb file
@@ -389,14 +382,19 @@ def fetch1_dataframe(self):
 class LinearizationParameters(dj.Lookup):
     """Choose whether to use an HMM to linearize position.
 
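(Aside on the make() refactor above: the ten positional arguments were
replaced by splatting the fetched parameter row into the call. A minimal
sketch of the pattern, with illustrative names. One caveat worth checking: a
row returned by fetch1() normally includes the table's primary key,
'position_info_param_name', which the target function must either accept or
have popped first:

    def calc_position_info(spatial_series, max_separation=9.0, max_speed=300.0):
        return {"series": spatial_series, "sep": max_separation, "speed": max_speed}

    params = {
        "position_info_param_name": "default",  # primary key from fetch1()
        "max_separation": 9.0,
        "max_speed": 300.0,
    }
    params.pop("position_info_param_name")  # drop key the function won't accept
    info = calc_position_info("spatial-series-placeholder", **params)
)
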
     This can help when the euclidean distances between separate arms are too close and the previous
-    position has some information about which arm the animal is on."""
+    position has some information about which arm the animal is on.
+
+    route_euclidean_distance_scaling : float
+        How much to prefer route distances between successive time points that
+        are closer to the euclidean distance. Smaller numbers mean the route
+        distance is more likely to be close to the euclidean distance.
+    """
 
     definition = """
     linearization_param_name : varchar(80)   # name for this set of parameters
     ---
     use_hmm = 0 : int   # use HMM to determine linearization
-    # How much to prefer route distances between successive time points that are closer to the euclidean distance. Smaller numbers mean the route distance is more likely to be close to the euclidean distance.
-    route_euclidean_distance_scaling = 1.0 : float
+    route_euclidean_distance_scaling = 1.0 : float # see docstring
     sensor_std_dev = 5.0 : float # Uncertainty of position sensor (in cm).
     # Biases the transition matrix to prefer the current track segment.
     diagonal_bias = 0.5 : float
@@ -411,11 +409,11 @@ class TrackGraph(dj.Manual):
     definition = """
     track_graph_name : varchar(80)
     ----
-    environment : varchar(80)    # Type of Environment
-    node_positions : blob        # 2D position of track_graph nodes, shape (n_nodes, 2)
-    edges: blob                  # shape (n_edges, 2)
-    linear_edge_order : blob     # order of track graph edges in the linear space, shape (n_edges, 2)
-    linear_edge_spacing : blob   # amount of space between edges in the linear space, shape (n_edges,)
+    environment : varchar(80)   # Type of Environment
+    node_positions : blob       # 2D position of graph nodes (n_nodes, 2)
+    edges: blob                 # shape (n_edges, 2)
+    linear_edge_order : blob    # order of edges in linear space (n_edges, 2)
+    linear_edge_spacing : blob  # space between edges in linear space (n_edges,)
     """
 
     def get_networkx_track_graph(self, track_graph_parameters=None):
@@ -581,7 +579,8 @@ def __init__(
             ax.imshow(frame, picker=True)
             ax.set_title(
                 "Left click to place node.\nRight click to remove node."
- "\nShift+Left click to clear nodes.\nCntrl+Left click two nodes to place an edge" + "\nShift+Left click to clear nodes.\nCntrl+Left click two nodes" + " to place an edge" ) self.connect() @@ -696,20 +695,18 @@ def make(self, key): M_TO_CM = 100 print("Loading position data...") + this_file = {"nwb_file_name": key["nwb_file_name"]} + this_interval = {"interval_list_name": key["interval_list_name"]} + this_param = { + "position_info_param_name": key["position_info_param_name"] + } + raw_position_df = ( - RawPosition() - & { - "nwb_file_name": key["nwb_file_name"], - "interval_list_name": key["interval_list_name"], - } + RawPosition() & this_file & this_interval ).fetch1_dataframe() + position_info_df = ( - IntervalPositionInfo() - & { - "nwb_file_name": key["nwb_file_name"], - "interval_list_name": key["interval_list_name"], - "position_info_param_name": key["position_info_param_name"], - } + IntervalPositionInfo() & this_file & this_interval & this_param ).fetch1_dataframe() print("Loading video data...") @@ -721,11 +718,10 @@ def make(self, key): ) + 1 ) - video_info = ( - VideoFile() - & {"nwb_file_name": key["nwb_file_name"], "epoch": epoch} - ).fetch1() + video_info = (VideoFile() & this_file & {"epoch": epoch}).fetch1() + io = pynwb.NWBHDF5IO(raw_dir() + video_info["nwb_file_name"], "r") + nwb_file = io.read() nwb_video = nwb_file.objects[video_info["video_file_object_id"]] video_filename = nwb_video.external_file.value[0] @@ -740,6 +736,7 @@ def make(self, key): "red": np.asarray(raw_position_df[["xloc", "yloc"]]), "green": np.asarray(raw_position_df[["xloc2", "yloc2"]]), } + head_position_mean = np.asarray( position_info_df[["head_position_x", "head_position_y"]] ) @@ -779,7 +776,32 @@ def convert_to_pixels(data, frame_size, cm_to_pixels=1.0): return data / cm_to_pixels @staticmethod - def fill_nan(variable, video_time, variable_time): + def fill_nan( + variable: np.ndarray, video_time: np.ndarray, variable_time: np.ndarray + ) -> np.ndarray: + """ + Fills missing vals in variable with NaN based on the video_time array. + + Parameters + ---------- + variable : np.ndarray + The variable array with missing values. + video_time : np.ndarray + The array of video times. + variable_time : np.ndarray + The array of variable times. + + Returns + ------- + filled_variable : np.ndarray + The variable array with NaN values filled based on video time. + + Raises + ------ + IndexError + If the variable array does not have more than one dimension. + """ + video_ind = np.digitize(variable_time, video_time[1:]) n_video_time = len(video_time) diff --git a/src/spyglass/common/common_region.py b/src/spyglass/common/common_region.py index b21e99cfa..97d8f26c8 100644 --- a/src/spyglass/common/common_region.py +++ b/src/spyglass/common/common_region.py @@ -13,16 +13,18 @@ class BrainRegion(dj.Lookup): subsubregion_name=NULL: varchar(200) # subregion within subregion """ - # TODO consider making (region_name, subregion_name, subsubregion_name) a primary key - # subregion_name='' and subsubregion_name='' will be necessary but that seems OK + # TODO consider making (region_name, subregion_name, subsubregion_name) a + # primary key subregion_name='' and subsubregion_name='' will be necessary + # but that seems OK @classmethod def fetch_add( cls, region_name, subregion_name=None, subsubregion_name=None ): - """Return the region ID for the given names, and if no match exists, first add it to the BrainRegion table. + """Return the region ID for names. If no match, add to the BrainRegion. 
-        The combination of (region_name, subregion_name, subsubregion_name) is effectively unique, then.
+        The combination of (region_name, subregion_name, subsubregion_name) is
+        effectively unique, then.
 
         Parameters
         ----------
@@ -38,12 +40,10 @@ def fetch_add(
             region_id : int
                 The index of the region in the BrainRegion table.
         """
-        key = dict()
-        key["region_name"] = region_name
-        key["subregion_name"] = subregion_name
-        key["subsubregion_name"] = subsubregion_name
-        query = BrainRegion & key
-        if not query:
-            cls.insert1(key)
-            query = BrainRegion & key
-        return query.fetch1("region_id")
+        key = dict(
+            region_name=region_name,
+            subregion_name=subregion_name,
+            subsubregion_name=subsubregion_name,
+        )
+        cls.insert1(key, skip_duplicates=True)
+        return (BrainRegion & key).fetch1("region_id")
diff --git a/src/spyglass/spikesorting/spikesorting_recording.py b/src/spyglass/spikesorting/spikesorting_recording.py
index 593ff8db1..cdad793e9 100644
--- a/src/spyglass/spikesorting/spikesorting_recording.py
+++ b/src/spyglass/spikesorting/spikesorting_recording.py
@@ -21,6 +21,7 @@
 from ..common.common_nwbfile import Nwbfile
 from ..common.common_session import Session  # noqa: F401
 from ..utils.dj_helper_fn import dj_replace
+from ..settings import recording_dir
 
 schema = dj.schema("spikesorting_recording")
 
@@ -379,59 +380,67 @@ def make(self, key):
         recording = self._get_filtered_recording(key)
         recording_name = self._get_recording_name(key)
 
-        tmp_key = {
-            "nwb_file_name": key["nwb_file_name"],
-            "interval_list_name": recording_name,
-            "valid_times": sort_interval_valid_times,
-        }
-        IntervalList.insert1(tmp_key, replace=True)
+        # Path to files that will hold the recording extractors
+        recording_path = str(recording_dir / Path(recording_name))
+        if os.path.exists(recording_path):
+            shutil.rmtree(recording_path)
 
-        # store the list of valid times for the sort
-        key["sort_interval_list_name"] = tmp_key["interval_list_name"]
+        recording.save(
+            folder=recording_path, chunk_duration="10000ms", n_jobs=8
+        )
 
-        # Path to files that will hold the recording extractors
-        recording_folder = Path(os.getenv("SPYGLASS_RECORDING_DIR"))
-        key["recording_path"] = str(recording_folder / Path(recording_name))
-        if os.path.exists(key["recording_path"]):
-            shutil.rmtree(key["recording_path"])
-        recording = recording.save(
-            folder=key["recording_path"], chunk_duration="10000ms", n_jobs=8
+        IntervalList.insert1(
+            {
+                "nwb_file_name": key["nwb_file_name"],
+                "interval_list_name": recording_name,
+                "valid_times": sort_interval_valid_times,
+            },
+            replace=True,
         )
-        self.insert1(key)
+        self.insert1(
+            {
+                **key,
+                # store the list of valid times for the sort
+                "sort_interval_list_name": recording_name,
+                "recording_path": recording_path,
+            }
+        )
 
     @staticmethod
     def _get_recording_name(key):
-        recording_name = (
-            key["nwb_file_name"]
-            + "_"
-            + key["sort_interval_name"]
-            + "_"
-            + str(key["sort_group_id"])
-            + "_"
-            + key["preproc_params_name"]
+        return "_".join(
+            [
+                key["nwb_file_name"],
+                key["sort_interval_name"],
+                str(key["sort_group_id"]),
+                key["preproc_params_name"],
+            ]
         )
-        return recording_name
 
     @staticmethod
     def _get_recording_timestamps(recording):
-        if recording.get_num_segments() > 1:
-            frames_per_segment = [0]
-            for i in range(recording.get_num_segments()):
-                frames_per_segment.append(
-                    recording.get_num_frames(segment_index=i)
-                )
+        num_segments = recording.get_num_segments()
 
-            cumsum_frames = np.cumsum(frames_per_segment)
-            total_frames = np.sum(frames_per_segment)
+        if num_segments <= 1:
+            return recording.get_times()
+
+        frames_per_segment = [0] + [
+            recording.get_num_frames(segment_index=i)
+            for i in range(num_segments)
+        ]
+
+        cumsum_frames = np.cumsum(frames_per_segment)
+        total_frames = np.sum(frames_per_segment)
+
+        timestamps = np.zeros((total_frames,))
+        for i in range(num_segments):
+            start_index = cumsum_frames[i]
+            end_index = cumsum_frames[i + 1]
+            timestamps[start_index:end_index] = recording.get_times(
+                segment_index=i
+            )
 
-            timestamps = np.zeros((total_frames,))
-            for i in range(recording.get_num_segments()):
-                timestamps[
-                    cumsum_frames[i] : cumsum_frames[i + 1]
-                ] = recording.get_times(segment_index=i)
-        else:
-            timestamps = recording.get_times()
         return timestamps
 
     def _get_sort_interval_valid_times(self, key):
@@ -456,9 +465,11 @@ def _get_sort_interval_valid_times(self, key):
                 "sort_interval_name": key["sort_interval_name"],
             }
         ).fetch1("sort_interval")
+
         interval_list_name = (SpikeSortingRecordingSelection & key).fetch1(
             "interval_list_name"
         )
+
         valid_interval_times = (
             IntervalList
             & {
                 "nwb_file_name": key["nwb_file_name"],
                 "interval_list_name": interval_list_name,
             }
         ).fetch1("valid_times")
+
         valid_sort_times = interval_list_intersect(
             sort_interval, valid_interval_times
         )
diff --git a/src/spyglass/utils/nwb_helper_fn.py b/src/spyglass/utils/nwb_helper_fn.py
index 58b9b4696..ac576469c 100644
--- a/src/spyglass/utils/nwb_helper_fn.py
+++ b/src/spyglass/utils/nwb_helper_fn.py
@@ -21,7 +21,9 @@ def get_nwb_file(nwb_file_path):
     """Return an NWBFile object with the given file path in read mode.
+
-    If the file is not found locally, this will check if it has been shared with kachery and if so, download it and open it.
+    If the file is not found locally, this will check if it has been shared with
+    kachery and if so, download it and open it.
 
     Parameters
     ----------
@@ -40,7 +42,8 @@ def get_nwb_file(nwb_file_path):
     # check to see if the file exists
     if not os.path.exists(nwb_file_path):
         print(
-            f"NWB file {nwb_file_path} does not exist locally; checking kachery"
+            f"NWB file {nwb_file_path} does not exist locally; "
+            + "checking kachery"
         )
         # first try the analysis files
         from ..sharing.sharing_kachery import AnalysisNwbfileKachery
@@ -98,7 +101,9 @@ def close_nwb_files():
 
 def get_data_interface(nwbfile, data_interface_name, data_interface_class=None):
-    """Search for a specified NWBDataInterface or DynamicTable in the processing modules of an NWB file.
+    """
+    Search for a specified NWBDataInterface or DynamicTable in the processing
+    modules of an NWB file.
 
     Parameters
     ----------
@@ -107,13 +112,15 @@ def get_data_interface(nwbfile, data_interface_name, data_interface_class=None):
     nwbfile : pynwb.NWBFile
         The NWB file object to search in.
     data_interface_name : str
         The name of the NWBDataInterface or DynamicTable to search for.
     data_interface_class : type, optional
-        The class (or superclass) to search for. This argument helps to prevent accessing an object with the same
-        name but the incorrect type. Default: no restriction.
+        The class (or superclass) to search for. This argument helps to prevent
+        accessing an object with the same name but the incorrect type. Default:
+        no restriction.
 
     Warns
     -----
     UserWarning
-        If multiple NWBDataInterface and DynamicTable objects with the matching name are found.
+        If multiple NWBDataInterface and DynamicTable objects with the matching
+        name are found.
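
(Aside on the _get_recording_timestamps hunk above: the rewritten branch
builds one flat timestamp vector by offsetting each segment with a cumulative
frame count. Reduced to plain numpy, with segments as a list of arrays:

    import numpy as np

    segments = [np.array([0.0, 0.1, 0.2]), np.array([5.0, 5.1])]  # example
    frames_per_segment = [0] + [len(s) for s in segments]
    cumsum_frames = np.cumsum(frames_per_segment)

    timestamps = np.zeros(cumsum_frames[-1])
    for i, seg in enumerate(segments):
        timestamps[cumsum_frames[i]:cumsum_frames[i + 1]] = seg

    # timestamps -> [0.0, 0.1, 0.2, 5.0, 5.1]
)
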
Returns ------- @@ -132,8 +139,9 @@ def get_data_interface(nwbfile, data_interface_name, data_interface_class=None): if len(ret) > 1: warnings.warn( f"Multiple data interfaces with name '{data_interface_name}' " - f"found in NWBFile with identifier {nwbfile.identifier}. Using the first one found. " - "Use the data_interface_class argument to restrict the search." + f"found in NWBFile with identifier {nwbfile.identifier}. " + "Using the first one found. Use the data_interface_class argument " + "to restrict the search." ) if len(ret) >= 1: return ret[0] @@ -168,8 +176,9 @@ def get_raw_eseries(nwbfile): def estimate_sampling_rate(timestamps, multiplier): """Estimate the sampling rate given a list of timestamps. - Assumes that the most common temporal differences between timestamps approximate the sampling rate. Note that this - can fail for very high sampling rates and irregular timestamps. + Assumes that the most common temporal differences between timestamps + approximate the sampling rate. Note that this can fail for very high + sampling rates and irregular timestamps. Parameters ---------- @@ -185,7 +194,8 @@ def estimate_sampling_rate(timestamps, multiplier): # approach: # 1. use a box car smoother and a histogram to get the modal value - # 2. identify adjacent samples as those that have a time difference < the multiplier * the modal value + # 2. identify adjacent samples as those that have a + # time difference < the multiplier * the modal value # 3. average the time differences between adjacent samples sample_diff = np.diff(timestamps[~np.isnan(timestamps)]) if len(sample_diff) < 10: @@ -196,7 +206,8 @@ def estimate_sampling_rate(timestamps, multiplier): smoother = np.ones(nsmooth) / nsmooth smooth_diff = np.convolve(sample_diff, smoother, mode="same") - # we histogram with 100 bins out to 3 * mean, which should be fine for any reasonable number of samples + # we histogram with 100 bins out to 3 * mean, which should be fine for any + # reasonable number of samples hist, bins = np.histogram( smooth_diff, bins=100, range=[0, 3 * np.mean(smooth_diff)] )
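
(Aside on estimate_sampling_rate above: the modal inter-sample interval is
what the box-car smoother and histogram recover. A runnable sketch of step 1
with synthetic ~1 kHz timestamps; the real function then uses multiplier
times this modal value to identify adjacent samples and averages their time
differences:

    import numpy as np

    rng = np.random.default_rng(0)
    timestamps = np.arange(0, 1, 0.001) + rng.normal(0, 1e-5, 1000)
    sample_diff = np.diff(timestamps[~np.isnan(timestamps)])

    nsmooth = 10
    smooth_diff = np.convolve(
        sample_diff, np.ones(nsmooth) / nsmooth, mode="same"
    )

    hist, bins = np.histogram(
        smooth_diff, bins=100, range=[0, 3 * np.mean(smooth_diff)]
    )
    mode = bins[np.argmax(hist)]  # left edge of the modal bin, in seconds
    print(f"estimated rate ~ {1 / mode:.0f} Hz")  # about 1000 Hz
)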