diff --git a/src/spyglass/common/common_device.py b/src/spyglass/common/common_device.py index 7849e6099..e05be3fcf 100644 --- a/src/spyglass/common/common_device.py +++ b/src/spyglass/common/common_device.py @@ -40,31 +40,35 @@ class DataAcquisitionDevice(dj.Manual): def insert_from_nwbfile(cls, nwbf, config): """Insert data acquisition devices from an NWB file. - Note that this does not link the DataAcquisitionDevices with a Session. For that, - see DataAcquisitionDeviceList. + Note that this does not link the DataAcquisitionDevices with a Session. + For that, see DataAcquisitionDeviceList. Parameters ---------- nwbf : pynwb.NWBFile The source NWB file object. config : dict - Dictionary read from a user-defined YAML file containing values to replace in the NWB file. + Dictionary read from a user-defined YAML file containing values to + replace in the NWB file. """ _, ndx_devices, _ = cls.get_all_device_names(nwbf, config) for device_name in ndx_devices: new_device_dict = dict() - # read device properties into new_device_dict from PyNWB extension device object + # read device properties into new_device_dict from PyNWB extension + # device object nwb_device_obj = ndx_devices[device_name] name = nwb_device_obj.name adc_circuit = nwb_device_obj.adc_circuit - # transform system value. check if value is in DB. if not, prompt user to add an entry or cancel. + # transform system value. check if value is in DB. if not, prompt + # user to add an entry or cancel. system = cls._add_system(nwb_device_obj.system) - # transform amplifier value. check if value is in DB. if not, prompt user to add an entry or cancel. + # transform amplifier value. check if value is in DB. if not, prompt + # user to add an entry or cancel. amplifier = cls._add_amplifier(nwb_device_obj.amplifier) # standardize how Intan is represented in the database @@ -80,29 +84,34 @@ def insert_from_nwbfile(cls, nwbf, config): if ndx_devices: print( - f"Inserted or referenced data acquisition device(s): {ndx_devices.keys()}" + "Inserted or referenced data acquisition device(s): " + + f"{ndx_devices.keys()}" ) else: print("No conforming data acquisition device metadata found.") @classmethod - def get_all_device_names(cls, nwbf, config): - """Get a list of all device names in the NWB file, after appending and overwriting by the config file. + def get_all_device_names(cls, nwbf, config) -> tuple: + """ + Return device names in the NWB file, after appending and overwriting by + the config file. Parameters ---------- nwbf : pynwb.NWBFile The source NWB file object. config : dict - Dictionary read from a user-defined YAML file containing values to replace in the NWB file. + Dictionary read from a user-defined YAML file containing values to + replace in the NWB file. Returns ------- - device_name_list : list + device_name_list : tuple List of data acquisition object names found in the NWB file. 
""" - # make a dict mapping device name to PyNWB device object for all devices in the NWB file that are - # of type ndx_franklab_novela.DataAcqDevice and thus have the required metadata + # make a dict mapping device name to PyNWB device object for all devices + # in the NWB file that are of type ndx_franklab_novela.DataAcqDevice and + # thus have the required metadata ndx_devices = { device_obj.name: device_obj for device_obj in nwbf.devices.values() @@ -124,10 +133,11 @@ def get_all_device_names(cls, nwbf, config): @classmethod def _add_device(cls, new_device_dict): - """Check that the information in the NWB file and the database for the given device name match perfectly. + """Ensure match betweent NWB file info & database entry. - If no DataAcquisitionDevice with the given name exists in the database, check whether the user wants to add - a new entry instead of referencing an existing entry. If so, return. If not, raise an exception. + If no DataAcquisitionDevice with the given name exists in the database, + check whether the user wants to add a new entry instead of referencing + an existing entry. If so, return. If not, raise an exception. Parameters ---------- @@ -137,48 +147,48 @@ def _add_device(cls, new_device_dict): Raises ------ PopulateException - If user chooses not to add a device to the database when prompted or if the device properties from the - NWB file do not match the properties of the corresponding database entry. + If user chooses not to add a device to the database when prompted or + if the device properties from the NWB file do not match the + properties of the corresponding database entry. """ name = new_device_dict["data_acquisition_device_name"] all_values = DataAcquisitionDevice.fetch( "data_acquisition_device_name" ).tolist() if name not in all_values: - # no entry with the same name exists, prompt user about adding a new entry + # no entry with the same name exists, prompt user to add a new entry print( - f"\nData acquisition device '{name}' was not found in the database. " - f"The current values are: {all_values}. " - "Please ensure that the device you want to add does not already " - "exist in the database under a different name or spelling. " + f"\nData acquisition device '{name}' was not found in the " + f"database. The current values are: {all_values}. " + "Please ensure that the device you want to add does not already" + " exist in the database under a different name or spelling. " "If you want to use an existing device in the database, " - "please change the corresponding Device object in the NWB file. " - "Entering 'N' will raise an exception." - ) - val = input( - f"Do you want to add data acquisition device '{name}' to the database? (y/N)" + "please change the corresponding Device object in the NWB file." + " Entering 'N' will raise an exception." ) + to_db = " to the database" + val = input(f"Add data acquisition device '{name}'{to_db}? (y/N)") if val.lower() in ["y", "yes"]: cls.insert1(new_device_dict, skip_duplicates=True) return raise PopulateException( - f"User chose not to add data acquisition device '{name}' to the database." + f"User chose not to add device '{name}'{to_db}." 
) - # effectively else (entry exists) - # check whether the values provided match the values stored in the database + # Check if values provided match the values stored in the database db_dict = ( DataAcquisitionDevice & {"data_acquisition_device_name": name} ).fetch1() if db_dict != new_device_dict: raise PopulateException( - f"Data acquisition device properties of PyNWB Device object with name '{name}': " - f"{new_device_dict} do not match properties of the corresponding database entry: {db_dict}." + "Data acquisition device properties of PyNWB Device object " + + f"with name '{name}': {new_device_dict} do not match " + f"properties of the corresponding database entry: {db_dict}." ) @classmethod def _add_system(cls, system): - """Check the system value. If it is not in the database, prompt the user to add the value to the database. + """Check the system value. If not in the db, prompt user to add it. This method also renames the system value "MCU" to "SpikeGadgets". @@ -190,7 +200,8 @@ def _add_system(cls, system): Raises ------ PopulateException - If user chooses not to add a device system value to the database when prompted. + If user chooses not to add a device system value to the database + when prompted. Returns ------- @@ -205,29 +216,31 @@ def _add_system(cls, system): ).tolist() if system not in all_values: print( - f"\nData acquisition device system '{system}' was not found in the database. " - f"The current values are: {all_values}. " - "Please ensure that the system you want to add does not already " - "exist in the database under a different name or spelling. " + f"\nData acquisition device system '{system}' was not found in" + f" the database. The current values are: {all_values}. " + "Please ensure that the system you want to add does not already" + " exist in the database under a different name or spelling. " "If you want to use an existing system in the database, " - "please change the corresponding Device object in the NWB file. " - "Entering 'N' will raise an exception." + "please change the corresponding Device object in the NWB file." + " Entering 'N' will raise an exception." ) val = input( - f"Do you want to add data acquisition device system '{system}' to the database? (y/N)" + f"Do you want to add data acquisition device system '{system}'" + + " to the database? (y/N)" ) if val.lower() in ["y", "yes"]: key = {"data_acquisition_device_system": system} DataAcquisitionDeviceSystem.insert1(key, skip_duplicates=True) else: raise PopulateException( - f"User chose not to add data acquisition device system '{system}' to the database." + "User chose not to add data acquisition device system " + + f"'{system}' to the database." ) return system @classmethod def _add_amplifier(cls, amplifier): - """Check the amplifier value. If it is not in the database, prompt the user to add the value to the database. + """Check amplifier value. If not in db, prompt user to add. Parameters ---------- @@ -237,7 +250,8 @@ def _add_amplifier(cls, amplifier): Raises ------ PopulateException - If user chooses not to add a device amplifier value to the database when prompted. + If user chooses not to add a device amplifier value to the database + when prompted. Returns ------- @@ -253,16 +267,17 @@ def _add_amplifier(cls, amplifier): ).tolist() if amplifier not in all_values: print( - f"\nData acquisition device amplifier '{amplifier}' was not found in the database. " - f"The current values are: {all_values}. 
" - "Please ensure that the amplifier you want to add does not already " - "exist in the database under a different name or spelling. " - "If you want to use an existing name in the database, " - "please change the corresponding Device object in the NWB file. " - "Entering 'N' will raise an exception." + f"\nData acquisition device amplifier '{amplifier}' was not " + f"found in the database. The current values are: {all_values}. " + "Please ensure that the amplifier you want to add does not " + "already exist in the database under a different name or " + "spelling. If you want to use an existing name in the database," + " please change the corresponding Device object in the NWB " + "file. Entering 'N' will raise an exception." ) val = input( - f"Do you want to add data acquisition device amplifier '{amplifier}' to the database? (y/N)" + "Do you want to add data acquisition device amplifier " + + f"'{amplifier}' to the database? (y/N)" ) if val.lower() in ["y", "yes"]: key = {"data_acquisition_device_amplifier": amplifier} @@ -271,7 +286,8 @@ def _add_amplifier(cls, amplifier): ) else: raise PopulateException( - f"User chose not to add data acquisition device amplifier '{amplifier}' to the database." + "User chose not to add data acquisition device amplifier " + + f"'{amplifier}' to the database." ) return amplifier @@ -306,14 +322,17 @@ def insert_from_nwbfile(cls, nwbf): for device in nwbf.devices.values(): if isinstance(device, ndx_franklab_novela.CameraDevice): device_dict = dict() - # TODO ideally the ID is not encoded in the name formatted in a particular way - # device.name must have the form "[any string without a space, usually camera] [int]" - device_dict["camera_id"] = int(str.split(device.name)[1]) - device_dict["camera_name"] = device.camera_name - device_dict["manufacturer"] = device.manufacturer - device_dict["model"] = device.model - device_dict["lens"] = device.lens - device_dict["meters_per_pixel"] = device.meters_per_pixel + # TODO ideally the ID is not encoded in the name formatted in a + # particular way device.name must have the form "[any string + # without a space, usually camera] [int]" + device_dict = { + "camera_id": int(device.name.split()[1]), + "camera_name": device.camera_name, + "manufacturer": device.manufacturer, + "model": device.model, + "lens": device.lens, + "meters_per_pixel": device.meters_per_pixel, + } cls.insert1(device_dict, skip_duplicates=True) device_name_list.append(device_dict["camera_name"]) if device_name_list: @@ -326,48 +345,50 @@ def insert_from_nwbfile(cls, nwbf): @schema class ProbeType(dj.Manual): definition = """ - # Type/category of probe, e.g., Neuropixels 1.0 or NeuroNexus X-Y-Z, regardless of configuration. - # This is a controlled vocabulary of probe type names. - # This is separated from Probe because probes like the Neuropixels 1.0 can have different dynamic configurations, - # e.g. channel maps. + # Type/category of probe regardless of configuration. Controlled vocabulary + # of probe type names. e.g., Neuropixels 1.0 or NeuroNexus X-Y-Z, etc. + # Separated from Probe because probes like the Neuropixels 1.0 can have + # different dynamic configurations e.g. channel maps. 
+ probe_type: varchar(80) --- - probe_description: varchar(2000) # description of this probe - manufacturer = "": varchar(200) # manufacturer of this probe - num_shanks: int # number of shanks on this probe + probe_description: varchar(2000) # description of this probe + manufacturer = "": varchar(200) # manufacturer of this probe + num_shanks: int # number of shanks on this probe """ @schema class Probe(dj.Manual): definition = """ - # A configuration of a ProbeType. For most probe types, there is only one configuration, and that configuration - # should always be used. For Neuropixels probes, the specific channel map (which electrodes are used, - # where are they, and in what order) can differ between users and sessions, and each configuration should have a - # different ProbeType. - probe_id: varchar(80) # a unique ID for this probe and dynamic configuration + # A configuration of a ProbeType. For most probe types, there is only one, + # which should always be used. For Neuropixels, the channel map (which + # electrodes used, where they are, and in what order) can differ between + # users and sessions. Each config should have a different ProbeType. + probe_id: varchar(80) # a unique ID for this probe & dynamic config --- - -> ProbeType # the type of probe, selected from a controlled list of probe types - -> [nullable] DataAcquisitionDevice # the data acquisition device used with this Probe - contact_side_numbering: enum("True", "False") # if True, then electrode contacts are facing you when numbering them + -> ProbeType # Type of probe, selected from a controlled list + -> [nullable] DataAcquisitionDevice # the data acquisition device used + contact_side_numbering: enum("True", "False") # Facing you when numbering """ class Shank(dj.Part): definition = """ -> Probe - probe_shank: int # shank number within probe. should be unique within a Probe + probe_shank: int # unique shank number within probe. """ class Electrode(dj.Part): definition = """ + # Electrode configuration, with ID, contact size, X/Y/Z coordinates -> Probe.Shank - probe_electrode: int # electrode ID that is output from the data acquisition system - # probe_electrode should be unique within a Probe + probe_electrode: int # electrode ID, output from acquisition + # system. Unique within a Probe --- contact_size = NULL: float # (um) contact size - rel_x = NULL: float # (um) x coordinate of the electrode within the probe - rel_y = NULL: float # (um) y coordinate of the electrode within the probe - rel_z = NULL: float # (um) z coordinate of the electrode within the probe + rel_x = NULL: float # (um) x coordinate of electrode + rel_y = NULL: float # (um) y coordinate of electrode + rel_z = NULL: float # (um) z coordinate of electrode """ @classmethod def insert_from_nwbfile(cls, nwbf, config): """Insert probe devices from an NWB file. Parameters ---------- nwbf : pynwb.NWBFile The source NWB file object. config : dict - Dictionary read from a user-defined YAML file containing values to replace in the NWB file. + Dictionary read from a user-defined YAML file containing values to + replace in the NWB file.
Returns ------- @@ -396,7 +418,8 @@ def insert_from_nwbfile(cls, nwbf, config): num_shanks = 0 if probe_type in ndx_probes: - # read probe properties into new_probe_dict from PyNWB extension probe object + # read probe properties into new_probe_dict from PyNWB extension + # probe object nwb_probe_obj = ndx_probes[probe_type] cls.__read_ndx_probe_data( nwb_probe_obj, @@ -412,13 +435,16 @@ def insert_from_nwbfile(cls, nwbf, config): shank_dict ), "`num_shanks` is not equal to the number of shanks." - # if probe id already exists, do not overwrite anything or create new Shanks and Electrodes - # TODO test whether the Shanks and Electrodes in the NWB file match the ones in the database + # if probe id already exists, do not overwrite anything or create + # new Shanks and Electrodes + # TODO: test whether the Shanks and Electrodes in the NWB file match + # the ones in the database query = Probe & {"probe_id": new_probe_dict["probe_id"]} if len(query) > 0: print( - f"Probe ID '{new_probe_dict['probe_id']}' already exists in the database. Spyglass will use " - "that and not create a new Probe, Shanks, or Electrodes." + f"Probe ID '{new_probe_dict['probe_id']}' already exists in" + " the database. Spyglass will use that and not create a new" + " Probe, Shanks, or Electrodes." ) continue @@ -438,14 +464,17 @@ def insert_from_nwbfile(cls, nwbf, config): @classmethod def get_all_probe_names(cls, nwbf, config): - """Get a list of all device names in the NWB file, after appending and overwriting by the config file. + """Get a list of all probe types in the NWB file. + + Includes all probe types, after appending/overwriting by the config file. Parameters ---------- nwbf : pynwb.NWBFile The source NWB file object. config : dict - Dictionary read from a user-defined YAML file containing values to replace in the NWB file. + Dictionary read from a user-defined YAML file containing values to + replace in the NWB file. Returns ------- @@ -453,21 +482,22 @@ def get_all_probe_names(cls, nwbf, config): List of data acquisition object names found in the NWB file.
""" - # make a dict mapping probe type to PyNWB object for all devices in the NWB file that are - # of type ndx_franklab_novela.Probe and thus have the required metadata + # make a dict mapping probe type to PyNWB object for all devices in the + # NWB file that are of type ndx_franklab_novela.Probe and thus have the + # required metadata ndx_probes = { device_obj.probe_type: device_obj for device_obj in nwbf.devices.values() if isinstance(device_obj, ndx_franklab_novela.Probe) } - # make a dict mapping probe type to dict of device metadata from the config YAML if exists - if "Probe" in config: - config_probes = [ - probe_dict["probe_type"] for probe_dict in config["Probe"] - ] - else: - config_probes = list() + # make a dict mapping probe type to dict of device metadata from the + # config YAML if exists + config_probes = ( + [probe_dict["probe_type"] for probe_dict in config["Probe"]] + if "Probe" in config + else list() + ) # get all the probe types from the NWB file plus the config YAML all_probes_types = set(ndx_probes.keys()).union(set(config_probes)) @@ -484,48 +514,46 @@ def __read_ndx_probe_data( elect_dict: dict, ): # construct dictionary of values to add to ProbeType - new_probe_type_dict["manufacturer"] = ( - getattr(nwb_probe_obj, "manufacturer") or "" + new_probe_type_dict.update( + { + "manufacturer": getattr(nwb_probe_obj, "manufacturer") or "", + "probe_type": nwb_probe_obj.probe_type, + "probe_description": nwb_probe_obj.probe_description, + "num_shanks": len(nwb_probe_obj.shanks), + } ) - new_probe_type_dict["probe_type"] = nwb_probe_obj.probe_type - new_probe_type_dict[ - "probe_description" - ] = nwb_probe_obj.probe_description - new_probe_type_dict["num_shanks"] = len(nwb_probe_obj.shanks) cls._add_probe_type(new_probe_type_dict) - new_probe_dict["probe_id"] = nwb_probe_obj.probe_type - new_probe_dict["probe_type"] = nwb_probe_obj.probe_type - new_probe_dict["contact_side_numbering"] = ( - "True" if nwb_probe_obj.contact_side_numbering else "False" + new_probe_dict.update( + { + "probe_id": nwb_probe_obj.probe_type, + "probe_type": nwb_probe_obj.probe_type, + "contact_side_numbering": "True" + if nwb_probe_obj.contact_side_numbering + else "False", + } ) - # go through the shanks and add each one to the Shank table for shank in nwb_probe_obj.shanks.values(): - shank_dict[shank.name] = dict() - shank_dict[shank.name]["probe_id"] = new_probe_dict["probe_type"] - shank_dict[shank.name]["probe_shank"] = int(shank.name) + shank_dict[shank.name] = { + "probe_id": new_probe_dict["probe_type"], + "probe_shank": int(shank.name), + } # go through the electrodes and add each one to the Electrode table for electrode in shank.shanks_electrodes.values(): - # the next line will need to be fixed if we have different sized contacts on a shank - elect_dict[electrode.name] = dict() - elect_dict[electrode.name]["probe_id"] = new_probe_dict[ - "probe_type" - ] - elect_dict[electrode.name]["probe_shank"] = shank_dict[ - shank.name - ]["probe_shank"] - elect_dict[electrode.name][ - "contact_size" - ] = nwb_probe_obj.contact_size - elect_dict[electrode.name]["probe_electrode"] = int( - electrode.name - ) - elect_dict[electrode.name]["rel_x"] = electrode.rel_x - elect_dict[electrode.name]["rel_y"] = electrode.rel_y - elect_dict[electrode.name]["rel_z"] = electrode.rel_z + # the next line will need to be fixed if we have different sized + # contacts on a shank + elect_dict[electrode.name] = { + "probe_id": new_probe_dict["probe_type"], + "probe_shank": shank_dict[shank.name]["probe_shank"], + 
"contact_size": nwb_probe_obj.contact_size, + "probe_electrode": int(electrode.name), + "rel_x": electrode.rel_x, + "rel_y": electrode.rel_y, + "rel_z": electrode.rel_z, + } @classmethod def _add_probe_type(cls, new_probe_type_dict): @@ -539,7 +567,8 @@ def _add_probe_type(cls, new_probe_type_dict): Raises ------ PopulateException - If user chooses not to add a probe type to the database when prompted. + If user chooses not to add a probe type to the database when + prompted. Returns ------- @@ -552,29 +581,32 @@ def _add_probe_type(cls, new_probe_type_dict): print( f"\nProbe type '{probe_type}' was not found in the database. " f"The current values are: {all_values}. " - "Please ensure that the probe type you want to add does not already " - "exist in the database under a different name or spelling. " - "If you want to use an existing name in the database, " - "please change the corresponding Probe object in the NWB file. " - "Entering 'N' will raise an exception." + "Please ensure that the probe type you want to add does not " + "already exist in the database under a different name or " + "spelling. If you want to use an existing name in the " + "database, please change the corresponding Probe object in the " + "NWB file. Entering 'N' will raise an exception." ) val = input( - f"Do you want to add probe type '{probe_type}' to the database? (y/N)" + f"Do you want to add probe type '{probe_type}' to the database?" + + " (y/N)" ) if val.lower() in ["y", "yes"]: ProbeType.insert1(new_probe_type_dict, skip_duplicates=True) return raise PopulateException( - f"User chose not to add probe type '{probe_type}' to the database." + f"User chose not to add probe type '{probe_type}' to the " + + "database." ) - # effectively else (entry exists) - # check whether the values provided match the values stored in the database + # else / entry exists: check whether the values provided match the + # values stored in the database db_dict = (ProbeType & {"probe_type": probe_type}).fetch1() if db_dict != new_probe_type_dict: raise PopulateException( - f"\nProbe type properties of PyNWB Probe object with name '{probe_type}': " - f"{new_probe_type_dict} do not match properties of the corresponding database entry: {db_dict}." + "\nProbe type properties of PyNWB Probe object with name " + f"'{probe_type}': {new_probe_type_dict} do not match properties" + f" of the corresponding database entry: {db_dict}." ) return probe_type @@ -587,23 +619,21 @@ def create_from_nwbfile( probe_type: str, contact_side_numbering: bool, ): - """Create a Probe entry and corresponding part table entries using the data in the NWB file. + """Create master/part Probe entry from the NWB file. - This method will parse the electrodes in the electrodes table, electrode groups (as shanks), and devices - (as probes) in the NWB file, but only ones that are associated with the device that matches the given + This method will parse the electrodes in the electrodes table, electrode + groups (as shanks), and devices (as probes) in the NWB file, but only + ones that are associated with the device that matches the given `nwb_device_name`. - Note that this code assumes the relatively standard convention where the NWB device corresponds to a Probe, - the NWB electrode group corresponds to a Shank, and the NWB electrode corresponds to an Electrode. 
+ Note that this code assumes the relatively standard convention where the + NWB device corresponds to a Probe, the NWB electrode group corresponds + to a Shank, and the NWB electrode corresponds to an Electrode. Example usage: ``` sgc.Probe.create_from_nwbfile( nwbfile=nwb_file_name, nwb_device_name="Device", probe_id="Neuropixels 1.0 Giocomo Lab Configuration", probe_type="Neuropixels 1.0", contact_side_numbering=True ) ``` @@ -612,13 +642,17 @@ nwb_file_name : str The name of the NWB file. nwb_device_name : str - The name of the PyNWB Device object that represents the probe to read in the NWB file. + The name of the PyNWB Device object that represents the probe to + read in the NWB file. probe_id : str - A unique ID for the probe and its configuration, to be used as the primary key for the new Probe entry. + A unique ID for the probe and its configuration, to be used as the + primary key for the new Probe entry. probe_type : str - The existing ProbeType entry that represents the type of probe being created. It must exist. + The existing ProbeType entry that represents the type of probe being + created. It must exist. contact_side_numbering : bool - Whether the electrode contacts are facing you when numbering them. Stored in the new Probe entry. + Whether the electrode contacts are facing you when numbering them. + Stored in the new Probe entry. """ from .common_nwbfile import Nwbfile @@ -633,49 +667,51 @@ ) return - new_probe_dict = dict() - shank_dict = dict() - elect_dict = dict() - - new_probe_dict["probe_id"] = probe_id - new_probe_dict["probe_type"] = probe_type - new_probe_dict["contact_side_numbering"] = ( - "True" if contact_side_numbering else "False" - ) + new_probe_dict = { + "probe_id": probe_id, + "probe_type": probe_type, + "contact_side_numbering": ( + "True" if contact_side_numbering else "False" + ), + } + shank_dict = {} + elect_dict = {} # iterate through the electrodes table in the NWB file # and use the group column (ElectrodeGroup) to create shanks # and use the device attribute of each ElectrodeGroup to create a probe - created_shanks = dict() # map device name to shank_index (int) + created_shanks = {} # map device name to shank_index (int) device_found = False for elec_index in range(len(nwbfile.electrodes)): electrode_group = nwbfile.electrodes[elec_index, "group"] eg_device_name = electrode_group.device.name - # only look at electrodes where the associated device is the one specified + # only look at electrodes where the associated device is the one + # specified if eg_device_name == nwb_device_name: device_found = True - # if a Shank has not yet been created from the electrode group, then create it + # if a Shank has not yet been created from the electrode group, + # then create it if electrode_group.name not in created_shanks: shank_index = len(created_shanks) created_shanks[electrode_group.name] = shank_index # build the dictionary of Probe.Shank data - shank_dict[shank_index] = dict() - shank_dict[shank_index]["probe_id"] = new_probe_dict[ - "probe_id" - ] - shank_dict[shank_index]["probe_shank"] = shank_index + shank_dict[shank_index] = { + "probe_id": new_probe_dict["probe_id"], + "probe_shank": shank_index, + } # get the probe shank index associated with this Electrode probe_shank = created_shanks[electrode_group.name] # build the
dictionary of Probe.Electrode data - elect_dict[elec_index] = dict() - elect_dict[elec_index]["probe_id"] = new_probe_dict["probe_id"] - elect_dict[elec_index]["probe_shank"] = probe_shank - elect_dict[elec_index]["probe_electrode"] = elec_index + elect_dict[elec_index] = { + "probe_id": new_probe_dict["probe_id"], + "probe_shank": probe_shank, + "probe_electrode": elec_index, + } if "rel_x" in nwbfile.electrodes[elec_index]: elect_dict[elec_index]["rel_x"] = nwbfile.electrodes[ elec_index, "rel_x" ] @@ -691,7 +727,8 @@ def create_from_nwbfile( if not device_found: print( - f"No electrodes in the NWB file were associated with a device named '{nwb_device_name}'." + "No electrodes in the NWB file were associated with a device " + + f"named '{nwb_device_name}'." ) return
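A note on the CameraDevice hunk above: `int(device.name.split()[1])` keeps the behavior of the old `int(str.split(device.name)[1])`. A minimal sketch of the naming convention the TODO describes, with a hypothetical device name (not taken from this diff):

```python
# Hypothetical device name following the "[string without spaces] [int]"
# convention noted in the TODO; the camera ID is the token after the space.
name = "camera 2"
camera_id = int(name.split()[1])
assert camera_id == 2
```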
diff --git a/src/spyglass/common/common_dio.py b/src/spyglass/common/common_dio.py index f5ad37de1..e8c12c662 100644 --- a/src/spyglass/common/common_dio.py +++ b/src/spyglass/common/common_dio.py @@ -34,11 +34,12 @@ def make(self, key): ) if behav_events is None: print( - f"No conforming behavioral events data interface found in {nwb_file_name}\n" + "No conforming behavioral events data interface found in " + + f"{nwb_file_name}\n" ) return - # the times for these events correspond to the valid times for the raw data + # Times for these events correspond to the valid times for the raw data key["interval_list_name"] = ( Raw() & {"nwb_file_name": nwb_file_name} ).fetch1("interval_list_name") @@ -55,8 +56,10 @@ def plot_all_dio_events(self): Examples -------- - > (DIOEvents & {'nwb_file_name': 'arthur20220314_.nwb'}).plot_all_dio_events() - > (DIOEvents & [{'nwb_file_name': "arthur20220314_.nwb"}, {"nwb_file_name": "arthur20220316_.nwb"}]).plot_all_dio_events() + > restr1 = {'nwb_file_name': 'arthur20220314_.nwb'} + > restr2 = {'nwb_file_name': 'arthur20220316_.nwb'} + > (DIOEvents & restr1).plot_all_dio_events() + > (DIOEvents & [restr1, restr2]).plot_all_dio_events() """ behavioral_events = self.fetch_nwb() diff --git a/src/spyglass/common/common_filter.py b/src/spyglass/common/common_filter.py index f7f741984..a2550330e 100644 --- a/src/spyglass/common/common_filter.py +++ b/src/spyglass/common/common_filter.py @@ -21,8 +21,9 @@ def _import_ghostipy(): return gsp except ImportError as e: raise ImportError( - "You must install ghostipy to use filtering methods. Please note that to install ghostipy on " - "an Mac M1, you must first install pyfftw from conda-forge." + "You must install ghostipy to use filtering methods. Please note " + "that to install ghostipy on a Mac M1, you must first install " + "pyfftw from conda-forge." ) from e @@ -33,201 +34,264 @@ class FirFilterParameters(dj.Manual): filter_sampling_rate: int # sampling rate for this filter --- filter_type: enum("lowpass", "highpass", "bandpass") - filter_low_stop = 0: float # lowest frequency for stop band for low frequency side of filter - filter_low_pass = 0: float # lowest frequency for pass band of low frequency side of filter - filter_high_pass = 0: float # highest frequency for pass band for high frequency side of filter - filter_high_stop = 0: float # highest frequency for stop band of high frequency side of filter + filter_low_stop = 0: float # lowest freq for stop band for low filt + filter_low_pass = 0: float # lowest freq for pass band of low filt + filter_high_pass = 0: float # highest freq for pass band for high filt + filter_high_stop = 0: float # highest freq for stop band of high filt filter_comments: varchar(2000) # comments about the filter - filter_band_edges: blob # numpy array containing the filter bands (redundant with individual parameters) - filter_coeff: longblob # numpy array containing the filter coefficients + filter_band_edges: blob # numpy array of filter bands + # redundant with individual parameters + filter_coeff: longblob # numpy array of filter coefficients """ - def add_filter(self, filter_name, fs, filter_type, band_edges, comments=""): - gsp = _import_ghostipy() + def add_filter( + self, + filter_name: str, + fs: float, + filter_type: str, + band_edges: list, + comments: str = "", + ) -> None: + """Add filter to the Filter table. + + Parameters + ---------- + filter_name: str + The name of the filter. + fs: float + The filter sampling rate. + filter_type: str + The type of the filter ('lowpass', 'highpass', or 'bandpass'). + band_edges: List[float] + The band edges for the filter. + comments: str, optional + Additional comments for the filter. Default "". + + Returns + ------- + None + Returns None if there is an error in the filter type or band + frequencies. + """ + VALID_FILTERS = {"lowpass": 2, "highpass": 2, "bandpass": 4} + FILTER_ERR = "Error in Filter.add_filter: " + FILTER_N_ERR = FILTER_ERR + "filter type {} requires {} band_frequencies." - # add an FIR bandpass filter of the specified type ('lowpass', 'highpass', or 'bandpass'). + # add an FIR bandpass filter of the specified type. # band_edges should be as follows: - # low pass filter: [high_pass high_stop] - # high pass filter: [low stop low pass] - # band pass filter: [low_stop low_pass high_pass high_stop]. + # low pass : [high_pass high_stop] + # high pass: [low stop low pass] + # band pass: [low_stop low_pass high_pass high_stop].
+ if filter_type not in VALID_FILTERS: print( - "Error in Filter.add_filter: filter type {} is not " - "lowpass" - ", " - "highpass" - " or " - """bandpass""".format(filter_type) + FILTER_ERR + + f"{filter_type} is not a valid type: {list(VALID_FILTERS)}" ) return None - p = 2 # transition spline will be quadratic - if filter_type == "lowpass" or filter_type == "highpass": - # check that two frequencies were passed in and that they are in the right order - if len(band_edges) != 2: - print( - "Error in Filter.add_filter: lowpass and highpass filter requires two band_frequencies" - ) - return None - tw = band_edges[1] - band_edges[0] + if len(band_edges) != VALID_FILTERS[filter_type]: + print(FILTER_N_ERR.format(filter_type, VALID_FILTERS[filter_type])) + return None - elif filter_type == "bandpass": - if len(band_edges) != 4: - print( - "Error in Filter.add_filter: bandpass filter requires four band_frequencies." - ) - return None - # the transition width is the mean of the widths of left and right transition regions - tw = ( + gsp = _import_ghostipy() + TRANS_SPLINE = 2 # transition spline will be quadratic + + if filter_type != "bandpass": + transition_width = band_edges[1] - band_edges[0] + + else: + # transition width is mean of left and right transition regions + transition_width = ( (band_edges[1] - band_edges[0]) + (band_edges[3] - band_edges[2]) ) / 2.0 - else: - raise Exception(f"Unexpected filter type: {filter_type}") - - numtaps = gsp.estimate_taps(fs, tw) - filterdict = dict() - filterdict["filter_name"] = filter_name - filterdict["filter_sampling_rate"] = fs - filterdict["filter_comments"] = comments + numtaps = gsp.estimate_taps(fs, transition_width) + filterdict = { + "filter_type": filter_type, + "filter_name": filter_name, + "filter_sampling_rate": fs, + "filter_comments": comments, + "filter_low_stop": 0, + "filter_low_pass": 0, + "filter_high_pass": 0, + "filter_high_stop": 0, + "filter_band_edges": np.asarray(band_edges), + } # set the desired frequency response if filter_type == "lowpass": desired = [1, 0] - filterdict["filter_low_stop"] = 0 - filterdict["filter_low_pass"] = 0 - filterdict["filter_high_pass"] = band_edges[0] - filterdict["filter_high_stop"] = band_edges[1] + pass_stop_dict = { + "filter_high_pass": band_edges[0], + "filter_high_stop": band_edges[1], + } elif filter_type == "highpass": desired = [0, 1] - filterdict["filter_low_stop"] = band_edges[0] - filterdict["filter_low_pass"] = band_edges[1] - filterdict["filter_high_pass"] = 0 - filterdict["filter_high_stop"] = 0 + pass_stop_dict = { + "filter_low_stop": band_edges[0], + "filter_low_pass": band_edges[1], + } else: desired = [0, 1, 1, 0] - filterdict["filter_low_stop"] = band_edges[0] - filterdict["filter_low_pass"] = band_edges[1] - filterdict["filter_high_pass"] = band_edges[2] - filterdict["filter_high_stop"] = band_edges[3] - filterdict["filter_type"] = filter_type - filterdict["filter_band_edges"] = np.asarray(band_edges) + pass_stop_dict = { + "filter_low_stop": band_edges[0], + "filter_low_pass": band_edges[1], + "filter_high_pass": band_edges[2], + "filter_high_stop": band_edges[3], + } + # create 1d array for coefficients - filterdict["filter_coeff"] = np.array( - gsp.firdesign(numtaps, band_edges, desired, fs=fs, p=p), ndmin=1 + filterdict.update( + { + **pass_stop_dict, + "filter_coeff": np.array( + gsp.firdesign( + numtaps, band_edges, desired, fs=fs, p=TRANS_SPLINE + ), + ndmin=1, + ), + } ) - # add this filter to the table + self.insert1(filterdict, skip_duplicates=True)
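For context, a hypothetical call to the refactored `add_filter` might look like the sketch below; the filter name and band edges are illustrative only, not values from this diff:

```python
from spyglass.common.common_filter import FirFilterParameters

# Illustrative only: a bandpass entry with band_edges in the order the
# comment above specifies: [low_stop, low_pass, high_pass, high_stop].
FirFilterParameters().add_filter(
    filter_name="theta 5-11 Hz",
    fs=1000.0,
    filter_type="bandpass",
    band_edges=[4, 5, 11, 12],
    comments="example theta-band filter",
)
```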
- def plot_magnitude(self, filter_name, fs): - filter = ( + def _filter_restrict(self, filter_name, fs): + return ( self & {"filter_name": filter_name} & {"filter_sampling_rate": fs} - ).fetch(as_dict=True) - f = filter[0] + ).fetch1() + + def plot_magnitude(self, filter_name, fs): + filter_dict = self._filter_restrict(filter_name, fs) plt.figure() - w, h = signal.freqz(filter[0]["filter_coeff"], worN=65536) + w, h = signal.freqz(filter_dict["filter_coeff"], worN=65536) magnitude = 20 * np.log10(np.abs(h)) plt.plot(w / np.pi * fs / 2, magnitude) plt.xlabel("Frequency (Hz)") plt.ylabel("Magnitude") plt.title("Frequency Response") - plt.xlim(0, np.max(f["filter_coeffand_edges"] * 2)) + plt.xlim(0, np.max(filter_dict["filter_band_edges"] * 2)) plt.ylim(np.min(magnitude), -1 * np.min(magnitude) * 0.1) plt.grid(True) def plot_fir_filter(self, filter_name, fs): - filter = ( - self & {"filter_name": filter_name} & {"filter_sampling_rate": fs} - ).fetch(as_dict=True) - f = filter[0] + filter_dict = self._filter_restrict(filter_name, fs) plt.figure() plt.clf() - plt.plot(f["filter_coeff"], "k") + plt.plot(filter_dict["filter_coeff"], "k") plt.xlabel("Coefficient") plt.ylabel("Magnitude") plt.title("Filter Taps") plt.grid(True) def filter_delay(self, filter_name, fs): - # return the filter delay - filter = ( - self & {"filter_name": filter_name} & {"filter_sampling_rate": fs} - ).fetch(as_dict=True) - return self.calc_filter_delay(filter["filter_coeff"]) + return self.calc_filter_delay( + self._filter_restrict(filter_name, fs)["filter_coeff"] + ) + + def _time_bound_check(self, start, stop, timestamps, nsamples): + timestamp_warn = "Interval time warning: " + if start < timestamps[0]: + warnings.warn( + timestamp_warn + + "start time smaller than first timestamp, " + + f"substituting first: {start} < {timestamps[0]}" + ) + start = timestamps[0] + + if stop > timestamps[-1]: + warnings.warn( + timestamp_warn + + "stop time larger than last timestamp, " + + f"substituting last: {stop} > {timestamps[-1]}" + ) + stop = timestamps[-1] + + frm, to = np.searchsorted(timestamps, (start, stop)) + to = min(to, nsamples) + return frm, to def filter_data_nwb( self, - analysis_file_abs_path, - eseries, - filter_coeff, - valid_times, - electrode_ids, - decimation, + analysis_file_abs_path: str, + eseries: pynwb.ecephys.ElectricalSeries, + filter_coeff: np.ndarray, + valid_times: np.ndarray, + electrode_ids: list, + decimation: int, description: str = "filtered data", type: Union[None, str] = None, ): """ - :param analysis_nwb_file_name: str full path to previously created analysis nwb file where filtered data - should be stored.
This also has the name of the original NWB file where the data will be taken from - :param eseries: electrical series with data to be filtered - :param filter_coeff: numpy array with filter coefficients for FIR filter - :param valid_times: 2D numpy array with start and stop times of intervals to be filtered - :param electrode_ids: list of electrode_ids to filter - :param decimation: int decimation factor - :return: The NWB object id of the filtered data (str), list containing first and last timestamp - - This function takes data and timestamps from an NWB electrical series and filters them using the ghostipy - package, saving the result as a new electricalseries in the nwb_file_name, which should have previously been - created and linked to the original NWB file using common_session.AnalysisNwbfile.create() + Filter data from an NWB electrical series using the ghostipy package, + and save the result as a new electrical series in the analysis NWB file. + + Parameters + ---------- + analysis_file_abs_path : str + Full path to the analysis NWB file. + eseries : pynwb.ecephys.ElectricalSeries + Electrical series with data to be filtered. + filter_coeff : np.ndarray + Array with filter coefficients for FIR filter. + valid_times : np.ndarray + Array with start and stop times of intervals to be filtered. + electrode_ids : list + List of electrode IDs to filter. + decimation : int + Decimation factor. + description : str + Description of the filtered data. + type : Union[None, str] + Type of data (e.g., "LFP"). + + Returns + ------- + tuple + The NWB object ID of the filtered data and a list containing the + first and last timestamps. """ + MEM_USE_LIMIT = 0.9 # fraction of available RAM permitted + gsp = _import_ghostipy() data_on_disk = eseries.data timestamps_on_disk = eseries.timestamps - n_dim = len(data_on_disk.shape) + n_samples = len(timestamps_on_disk) - # find the time_axis = 0 if data_on_disk.shape[0] == n_samples else 1 electrode_axis = 1 - time_axis + n_electrodes = data_on_disk.shape[electrode_axis] - input_dim_restrictions = [None] * n_dim + input_dim_restrictions = [None] * len(data_on_disk.shape) - # to get the input dimension restrictions we need to look at the electrode table for the eseries and get - # the indices from that + # Get input dimension restrictions input_dim_restrictions[electrode_axis] = np.s_[ get_electrode_indices(eseries, electrode_ids) ] indices = [] - output_shape_list = [0] * n_dim + output_shape_list = [0] * len(data_on_disk.shape) output_shape_list[electrode_axis] = len(electrode_ids) - output_offsets = [0] - - timestamp_size = timestamps_on_disk[0].itemsize - timestamp_dtype = timestamps_on_disk[0].dtype - data_size = data_on_disk[0][0].itemsize data_dtype = data_on_disk[0][0].dtype filter_delay = self.calc_filter_delay(filter_coeff) + + output_offsets = [0] + for a_start, a_stop in valid_times: - if a_start < timestamps_on_disk[0]: - warnings.warn( - f"Interval start time {a_start} is smaller than first timestamp {timestamps_on_disk[0]}, " - "using first timestamp instead" - ) - a_start = timestamps_on_disk[0] - if a_stop > timestamps_on_disk[-1]: - warnings.warn( - f"Interval stop time {a_stop} is larger than last timestamp {timestamps_on_disk[-1]}, " - "using last timestamp instead" - ) - a_stop = timestamps_on_disk[-1] - frm, to = np.searchsorted(timestamps_on_disk, (a_start, a_stop)) - if to > n_samples: - to = n_samples + frm, to = self._time_bound_check( + a_start, a_stop, timestamps_on_disk, n_samples + ) + indices.append((frm, to)) + + shape, _ = 
gsp.filter_data_fir( + + shape, _ = gsp.filter_data_fir( data_on_disk, filter_coeff, axis=time_axis, @@ -240,119 +304,107 @@ def filter_data_nwb( output_offsets.append(output_offsets[-1] + shape[time_axis]) output_shape_list[time_axis] += shape[time_axis] - # open the nwb file to create the dynamic table region and electrode series, then write and close the file + # Create dynamic table region and electrode series, write/close file with pynwb.NWBHDF5IO( path=analysis_file_abs_path, mode="a", load_namespaces=True ) as io: nwbf = io.read() + # get the indices of the electrodes in the electrode table elect_ind = get_electrode_indices(nwbf, electrode_ids) electrode_table_region = nwbf.create_electrode_table_region( elect_ind, "filtered electrode table" ) - eseries_name = "filtered data" es = pynwb.ecephys.ElectricalSeries( - name=eseries_name, + name="filtered data", data=np.empty(tuple(output_shape_list), dtype=data_dtype), electrodes=electrode_table_region, timestamps=np.empty(output_shape_list[time_axis]), description=description, ) if type == "LFP": - lfp = pynwb.ecephys.LFP(electrical_series=es) ecephys_module = nwbf.create_processing_module( name="ecephys", description=description ) - ecephys_module.add(lfp) + ecephys_module.add(pynwb.ecephys.LFP(electrical_series=es)) else: nwbf.add_scratch(es) + io.write(nwbf) - # reload the NWB file to get the h5py objects for the data and the timestamps - with pynwb.NWBHDF5IO( - path=analysis_file_abs_path, mode="a", load_namespaces=True - ) as io: - nwbf = io.read() - es = nwbf.objects[es.object_id] - filtered_data = es.data - new_timestamps = es.timestamps - indices = np.array(indices, ndmin=2) - # Filter and write the output dataset - ts_offset = 0 - - print("Filtering data") - for ii, (start, stop) in enumerate(indices): - # calculate the size of the timestamps and the data and determine whether they - # can be loaded into < 90% of available RAM - mem = psutil.virtual_memory() - interval_samples = stop - start - if ( - interval_samples - * (timestamp_size + n_electrodes * data_size) - < 0.9 * mem.available - ): - print(f"Interval {ii}: loading data into memory") - timestamps = np.asarray( - timestamps_on_disk[start:stop], - dtype=timestamp_dtype, - ) - if time_axis == 0: - data = np.asarray( - data_on_disk[start:stop, :], dtype=data_dtype - ) - else: - data = np.asarray( - data_on_disk[:, start:stop], dtype=data_dtype - ) - extracted_ts = timestamps[0::decimation] - new_timestamps[ - ts_offset : ts_offset + len(extracted_ts) - ] = extracted_ts - ts_offset += len(extracted_ts) - # filter the data - gsp.filter_data_fir( - data, - filter_coeff, - axis=time_axis, - input_index_bounds=[0, interval_samples - 1], - output_index_bounds=[ - filter_delay, - filter_delay + stop - start, - ], - ds=decimation, - input_dim_restrictions=input_dim_restrictions, - outarray=filtered_data, - output_offset=output_offsets[ii], + # Reload NWB file to get h5py objects for data/timestamps + # NOTE: CBroz - why io context within io context? 
Unindenting + with pynwb.NWBHDF5IO( + path=analysis_file_abs_path, mode="a", load_namespaces=True + ) as io: + nwbf = io.read() + es = nwbf.objects[es.object_id] + filtered_data = es.data + new_timestamps = es.timestamps + indices = np.array(indices, ndmin=2) + # Filter and write the output dataset + ts_offset = 0 + + print("Filtering data") + for ii, (start, stop) in enumerate(indices): + # Calc size of timestamps + data, check if < 90% of available RAM + interval_samples = stop - start + req_mem = interval_samples * ( + timestamps_on_disk[0].itemsize + + n_electrodes * data_on_disk[0][0].itemsize + ) + if req_mem < MEM_USE_LIMIT * psutil.virtual_memory().available: + print(f"Interval {ii}: loading data into memory") + timestamps = np.asarray( + timestamps_on_disk[start:stop], + dtype=timestamps_on_disk[0].dtype, + ) + if time_axis == 0: + data = np.asarray( + data_on_disk[start:stop, :], dtype=data_dtype ) else: - print(f"Interval {ii}: leaving data on disk") - data = data_on_disk - timestamps = timestamps_on_disk - extracted_ts = timestamps[start:stop:decimation] - new_timestamps[ - ts_offset : ts_offset + len(extracted_ts) - ] = extracted_ts - ts_offset += len(extracted_ts) - # filter the data - gsp.filter_data_fir( - data, - filter_coeff, - axis=time_axis, - input_index_bounds=[start, stop], - output_index_bounds=[ - filter_delay, - filter_delay + stop - start, - ], - ds=decimation, - input_dim_restrictions=input_dim_restrictions, - outarray=filtered_data, - output_offset=output_offsets[ii], + data = np.asarray( + data_on_disk[:, start:stop], dtype=data_dtype ) + extracted_ts = timestamps[0::decimation] + new_timestamps[ + ts_offset : ts_offset + len(extracted_ts) + ] = extracted_ts + ts_offset += len(extracted_ts) + input_index_bounds = [0, interval_samples - 1] + + else: + print(f"Interval {ii}: leaving data on disk") + data = data_on_disk + timestamps = timestamps_on_disk + extracted_ts = timestamps[start:stop:decimation] + new_timestamps[ + ts_offset : ts_offset + len(extracted_ts) + ] = extracted_ts + ts_offset += len(extracted_ts) + input_index_bounds = [start, stop] + + # filter the data + gsp.filter_data_fir( + data, + filter_coeff, + axis=time_axis, + input_index_bounds=input_index_bounds, + output_index_bounds=[ + filter_delay, + filter_delay + stop - start, + ], + ds=decimation, + input_dim_restrictions=input_dim_restrictions, + outarray=filtered_data, + output_offset=output_offsets[ii], + ) - start_end = [new_timestamps[0], new_timestamps[-1]] + start_end = [new_timestamps[0], new_timestamps[-1]] - io.write(nwbf) + io.write(nwbf) return es.object_id, start_end
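The interval loop above decides per interval whether to stream from disk; a minimal sketch of that RAM-budget check, with made-up array sizes (only `psutil.virtual_memory().available` is the real API):

```python
import numpy as np
import psutil

# An interval is loaded into memory only when its timestamps plus data
# fit within MEM_USE_LIMIT (90%) of the RAM currently available.
MEM_USE_LIMIT = 0.9
interval_samples = 30_000 * 600                # e.g., 10 min at 30 kHz
timestamp_size = np.dtype("float64").itemsize  # 8 bytes per timestamp
data_size = np.dtype("int16").itemsize         # 2 bytes per sample
n_electrodes = 128
req_mem = interval_samples * (timestamp_size + n_electrodes * data_size)
if req_mem < MEM_USE_LIMIT * psutil.virtual_memory().available:
    print("loading data into memory")
else:
    print("leaving data on disk")
```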
@@ -366,14 +418,26 @@ def filter_data( self, timestamps, data, filter_coeff, valid_times, electrodes, decimation, ): """ - :param timestamps: numpy array with list of timestamps for data - :param data: original data array - :param filter_coeff: numpy array with filter coefficients for FIR filter - :param valid_times: 2D numpy array with start and stop times of intervals to be filtered - :param electrodes: list of electrodes to filter - :param decimation: decimation factor - :return: filtered_data, timestamps + Parameters + ---------- + timestamps: numpy array + List of timestamps for data + data: numpy array + Original data array + filter_coeff: numpy array + Filter coefficients for FIR filter + valid_times: 2D numpy array + Start and stop times of intervals to be filtered + electrodes: list + Electrodes to filter + decimation: int + Decimation factor + + Returns + ------- + filtered_data, timestamps """ + gsp = _import_ghostipy() n_dim = len(data.shape) @@ -390,24 +454,13 @@ def filter_data( filter_delay = self.calc_filter_delay(filter_coeff) for a_start, a_stop in valid_times: - if a_start < timestamps[0]: - print( - f"Interval start time {a_start} is smaller than first timestamp " - f"{timestamps[0]}, using first timestamp instead" - ) - a_start = timestamps[0] - if a_stop > timestamps[-1]: - print( - f"Interval stop time {a_stop} is larger than last timestamp " - f"{timestamps[-1]}, using last timestamp instead" - ) - a_stop = timestamps[-1] - frm, to = np.searchsorted(timestamps, (a_start, a_stop)) - if to > n_samples: - to = n_samples + frm, to = self._time_bound_check( + a_start, a_stop, timestamps, n_samples + ) indices.append((frm, to)) - shape, dtype = gsp.filter_data_fir( + + shape, _ = gsp.filter_data_fir( data, filter_coeff, axis=time_axis, @@ -458,14 +511,20 @@ def filter_data( def calc_filter_delay(self, filter_coeff): """ - :param filter_coeff: - :return: filter delay + Parameters + ---------- + filter_coeff: numpy array + + Returns + ------- + filter delay: int """ return (len(filter_coeff) - 1) // 2 def create_standard_filters(self): - """Add standard filters to the Filter table including - 0-400 Hz low pass for continuous raw data -> LFP + """Add standard filters to the Filter table + + Includes 0-400 Hz low pass for continuous raw data -> LFP """ self.add_filter( "LFP 0-400 Hz",
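Why `calc_filter_delay` returns `(len(filter_coeff) - 1) // 2`: a symmetric (linear-phase) FIR filter delays every frequency component by half the filter order, so the output index bounds above are shifted by that many samples. A worked one-liner with illustrative taps:

```python
# Five illustrative taps; the group delay is (ntaps - 1) // 2 samples.
filter_coeff = [0.1, 0.2, 0.4, 0.2, 0.1]
assert (len(filter_coeff) - 1) // 2 == 2
```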
""" if nwbf.epochs is None: print("No epochs found in NWB file.") return + epochs = nwbf.epochs.to_dataframe() - for epoch_index, epoch_data in epochs.iterrows(): - epoch_dict = dict() - epoch_dict["nwb_file_name"] = nwb_file_name - if epoch_data.tags[0]: - epoch_dict["interval_list_name"] = epoch_data.tags[0] - else: - epoch_dict["interval_list_name"] = "interval_" + str( - epoch_index - ) - epoch_dict["valid_times"] = np.asarray( - [[epoch_data.start_time, epoch_data.stop_time]] - ) + + for _, epoch_data in epochs.iterrows(): + epoch_dict = { + "nwb_file_name": nwb_file_name, + "interval_list_name": epoch_data[1].tags[0] + if epoch_data[1].tags + else f"interval_{epoch_data[0]}", + "valid_times": np.asarray( + [[epoch_data[1].start_time, epoch_data[1].stop_time]] + ), + } + cls.insert1(epoch_dict, skip_duplicates=True) def plot_intervals(self, figsize=(20, 5)): @@ -145,7 +148,7 @@ def intervals_by_length(interval_list, min_length=0.0, max_length=1e10): Parameters ---------- interval_list : array_like - Each element is (start time, stop time), i.e. an interval. Unit is seconds. + Each element is (start time, stop time), i.e. an interval in seconds. min_length : float, optional Minimum interval length in seconds. Defaults to 0.0. max_length : float, optional @@ -158,12 +161,12 @@ def intervals_by_length(interval_list, min_length=0.0, max_length=1e10): def interval_list_contains_ind(interval_list, timestamps): - """Find indices of a list of timestamps that are contained in an interval list. + """Find indices of list of timestamps contained in an interval list. Parameters ---------- interval_list : array_like - Each element is (start time, stop time), i.e. an interval. Unit is seconds. + Each element is (start time, stop time), i.e. an interval in seconds. timestamps : array_like """ ind = [] @@ -184,7 +187,7 @@ def interval_list_contains(interval_list, timestamps): Parameters ---------- interval_list : array_like - Each element is (start time, stop time), i.e. an interval. Unit is seconds. + Each element is (start time, stop time), i.e. an interval in seconds. timestamps : array_like """ ind = [] @@ -200,28 +203,17 @@ def interval_list_contains(interval_list, timestamps): def interval_list_excludes_ind(interval_list, timestamps): - """Find indices of a list of timestamps that are not contained in an interval list. + """Find indices of timestamps that are not contained in an interval list. Parameters ---------- interval_list : array_like - Each element is (start time, stop time), i.e. an interval. Unit is seconds. + Each element is (start time, stop time), i.e. an interval in seconds. 
def plot_intervals(self, figsize=(20, 5)): @@ -145,7 +148,7 @@ def intervals_by_length(interval_list, min_length=0.0, max_length=1e10): Parameters ---------- interval_list : array_like - Each element is (start time, stop time), i.e. an interval. Unit is seconds. + Each element is (start time, stop time), i.e. an interval in seconds. min_length : float, optional Minimum interval length in seconds. Defaults to 0.0. max_length : float, optional @@ -158,12 +161,12 @@ def intervals_by_length(interval_list, min_length=0.0, max_length=1e10): def interval_list_contains_ind(interval_list, timestamps): - """Find indices of a list of timestamps that are contained in an interval list. + """Find indices of list of timestamps contained in an interval list. Parameters ---------- interval_list : array_like - Each element is (start time, stop time), i.e. an interval. Unit is seconds. + Each element is (start time, stop time), i.e. an interval in seconds. timestamps : array_like """ ind = [] @@ -184,7 +187,7 @@ def interval_list_contains(interval_list, timestamps): Parameters ---------- interval_list : array_like - Each element is (start time, stop time), i.e. an interval. Unit is seconds. + Each element is (start time, stop time), i.e. an interval in seconds. timestamps : array_like """ ind = [] @@ -200,28 +203,17 @@ def interval_list_contains(interval_list, timestamps): def interval_list_excludes_ind(interval_list, timestamps): - """Find indices of a list of timestamps that are not contained in an interval list. + """Find indices of timestamps that are not contained in an interval list. Parameters ---------- interval_list : array_like - Each element is (start time, stop time), i.e. an interval. Unit is seconds. + Each element is (start time, stop time), i.e. an interval in seconds. timestamps : array_like """ contained_inds = interval_list_contains_ind(interval_list, timestamps) return np.setdiff1d(np.arange(len(timestamps)), contained_inds) - # # add the first and last times to the list and creat a list of invalid intervals - # valid_times_list = np.ndarray.ravel(interval_list).tolist() - # valid_times_list.insert(0, timestamps[0] - 0.00001) - # valid_times_list.append(timestamps[-1] + 0.001) - # invalid_times = np.array(valid_times_list).reshape(-1, 2) - # # add the first and last timestamp indices - # ind = [] - # for invalid_time in invalid_times: - # ind += np.ravel(np.argwhere(np.logical_and(timestamps > invalid_time[0], - # timestamps < invalid_time[1]))).tolist() - # return np.asarray(ind) def interval_list_excludes(interval_list, timestamps): @@ -230,22 +222,24 @@ def interval_list_excludes(interval_list, timestamps): Parameters ---------- interval_list : array_like - Each element is (start time, stop time), i.e. an interval. Unit is seconds. + Each element is (start time, stop time), i.e. an interval in seconds. timestamps : array_like """ contained_times = interval_list_contains(interval_list, timestamps) return np.setdiff1d(timestamps, contained_times) - # # add the first and last times to the list and creat a list of invalid intervals - # valid_times_list = np.ravel(valid_times).tolist() - # valid_times_list.insert(0, timestamps[0] - 0.00001) - # valid_times_list.append(timestamps[-1] + 0.00001) - # invalid_times = np.array(valid_times_list).reshape(-1, 2) - # # add the first and last timestamp indices - # ind = [] - # for invalid_time in invalid_times: - # ind += np.ravel(np.argwhere(np.logical_and(timestamps > invalid_time[0], - # timestamps < invalid_time[1]))).tolist() - # return timestamps[ind] + + +def consolidate_intervals(interval_list): + if interval_list.ndim == 1: + interval_list = np.expand_dims(interval_list, 0) + else: + interval_list = interval_list[np.argsort(interval_list[:, 0])] + interval_list = reduce(_union_concat, interval_list) + # the following check is needed in the case where the interval list is a + # single element (behavior of reduce) + if interval_list.ndim == 1: + interval_list = np.expand_dims(interval_list, 0) + return interval_list def interval_list_intersect(interval_list1, interval_list2, min_length=0): @@ -265,76 +259,51 @@ interval_list: np.array, (N,2) """ - # first, consolidate interval lists to disjoint intervals by sorting and applying union - if interval_list1.ndim == 1: - interval_list1 = np.expand_dims(interval_list1, 0) - else: - interval_list1 = interval_list1[np.argsort(interval_list1[:, 0])] - interval_list1 = reduce(_union_concat, interval_list1) - # the following check is needed in the case where the interval list is a single element (behavior of reduce) - if interval_list1.ndim == 1: - interval_list1 = np.expand_dims(interval_list1, 0) - - if interval_list2.ndim == 1: - interval_list2 = np.expand_dims(interval_list2, 0) - else: - interval_list2 = interval_list2[np.argsort(interval_list2[:, 0])] - interval_list2 = reduce(_union_concat, interval_list2) - # the following check is needed in the case where the interval list is a single element (behavior of reduce) - if interval_list2.ndim == 1: - interval_list2 = np.expand_dims(interval_list2, 0) + # Consolidate interval lists to disjoint int'ls by sorting & applying union + interval_list1 = consolidate_intervals(interval_list1) + interval_list2 = consolidate_intervals(interval_list2)


 def interval_list_intersect(interval_list1, interval_list2, min_length=0):
@@ -265,76 +259,51 @@
         interval_list: np.array, (N,2)
     """
-    # first, consolidate interval lists to disjoint intervals by sorting and applying union
-    if interval_list1.ndim == 1:
-        interval_list1 = np.expand_dims(interval_list1, 0)
-    else:
-        interval_list1 = interval_list1[np.argsort(interval_list1[:, 0])]
-        interval_list1 = reduce(_union_concat, interval_list1)
-        # the following check is needed in the case where the interval list is a single element (behavior of reduce)
-        if interval_list1.ndim == 1:
-            interval_list1 = np.expand_dims(interval_list1, 0)
-
-    if interval_list2.ndim == 1:
-        interval_list2 = np.expand_dims(interval_list2, 0)
-    else:
-        interval_list2 = interval_list2[np.argsort(interval_list2[:, 0])]
-        interval_list2 = reduce(_union_concat, interval_list2)
-        # the following check is needed in the case where the interval list is a single element (behavior of reduce)
-        if interval_list2.ndim == 1:
-            interval_list2 = np.expand_dims(interval_list2, 0)
+    # Consolidate each interval list to disjoint intervals via sort & union
+    interval_list1 = consolidate_intervals(interval_list1)
+    interval_list2 = consolidate_intervals(interval_list2)

     # then do pairwise comparison and collect intersections
-    intersecting_intervals = []
-    for interval2 in interval_list2:
-        for interval1 in interval_list1:
-            if _intersection(interval2, interval1) is not None:
-                intersecting_intervals.append(
-                    _intersection(interval1, interval2)
-                )
+    intersecting_intervals = [
+        _intersection(interval2, interval1)
+        for interval2 in interval_list2
+        for interval1 in interval_list1
+        if _intersection(interval2, interval1) is not None
+    ]

     # if no intersection, then return an empty list
     if not intersecting_intervals:
         return []
-    else:
-        intersecting_intervals = np.asarray(intersecting_intervals)
-        intersecting_intervals = intersecting_intervals[
-            np.argsort(intersecting_intervals[:, 0])
-        ]
-        return intervals_by_length(
-            intersecting_intervals, min_length=min_length
-        )
+    intersecting_intervals = np.asarray(intersecting_intervals)
+    intersecting_intervals = intersecting_intervals[
+        np.argsort(intersecting_intervals[:, 0])
+    ]
+
+    return intervals_by_length(intersecting_intervals, min_length=min_length)


 def _intersection(interval1, interval2):
-    "Takes the (set-theoretic) intersection of two intervals"
-    intersection = np.array(
-        [max([interval1[0], interval2[0]]), min([interval1[1], interval2[1]])]
-    )
-    if intersection[1] > intersection[0]:
-        return intersection
-    else:
-        return None
+    """Take the (set-theoretic) intersection of two intervals"""
+    start = max(interval1[0], interval2[0])
+    end = min(interval1[1], interval2[1])
+    intersection = np.array([start, end]) if end > start else None
+    return intersection


 def _union(interval1, interval2):
-    "Takes the (set-theoretic) union of two intervals"
+    """Take the (set-theoretic) union of two intervals"""
     if _intersection(interval1, interval2) is None:
         return np.array([interval1, interval2])
-    else:
-        return np.array(
-            [
-                min([interval1[0], interval2[0]]),
-                max([interval1[1], interval2[1]]),
-            ]
-        )
+    return np.array(
+        [min(interval1[0], interval2[0]), max(interval1[1], interval2[1])]
+    )
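
A minimal editorial sketch of the two helpers' contracts, which `_union_concat`
relies on (assumes `np` is numpy):

    _intersection(np.array([0, 4]), np.array([2, 6]))  # -> array([2, 4])
    _intersection(np.array([0, 2]), np.array([2, 4]))  # -> None
    _union(np.array([0, 4]), np.array([2, 6]))         # -> array([0, 6])
    _union(np.array([0, 1]), np.array([4, 5]))         # -> array([[0, 1], [4, 5]])

Note the strict inequality: intervals that merely touch do not intersect, so
their union comes back stacked as a (2, 2) array rather than merged.
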


 def _union_concat(interval_list, interval):
-    """Compares the last interval of the interval list to the given interval and
-    * takes their union if overlapping
-    * concatenates the interval to the interval list if not
+    """Compare last interval of interval list to given interval.
+
+    If overlapping, take union. If not, concatenate interval to interval list.

     Recursively called with `reduce`.
     """
@@ -344,27 +313,23 @@ def _union_concat(interval_list, interval):
         interval = np.expand_dims(interval, 0)

     x = _union(interval_list[-1], interval[0])
-    if x.ndim == 1:
-        x = np.expand_dims(x, 0)
+    x = np.expand_dims(x, 0) if x.ndim == 1 else x
+
     return np.concatenate((interval_list[:-1], x), axis=0)


 def union_adjacent_index(interval1, interval2):
-    """unions two intervals that are adjacent in index
+    """Union index-adjacent intervals. If not adjacent, just concatenate.
+
+    e.g.
     [a,b] and [b+1, c] is converted to [a,c]
-    if not adjacent, just concatenates interval2 at the end of interval1

     Parameters
     ----------
     interval1 : np.array
-        [description]
     interval2 : np.array
-        [description]
     """
-    if interval1.ndim == 1:
-        interval1 = np.expand_dims(interval1, 0)
-    if interval2.ndim == 1:
-        interval2 = np.expand_dims(interval2, 0)
+    interval1 = np.atleast_2d(interval1)
+    interval2 = np.atleast_2d(interval2)

     if (
         interval1[-1][1] + 1 == interval2[0][0]
@@ -386,50 +351,63 @@
 # TODO: test interval_list_union code


+def _parallel_union(interval_list):
+    """Create a parallel list where 1 marks an interval start and -1 its end"""
+    interval_list = np.ravel(interval_list)
+    interval_list_start_end = np.ones(interval_list.shape)
+    interval_list_start_end[1::2] = -1
+    return interval_list, interval_list_start_end
+
+
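
For intuition, the new helper just flattens the intervals and tags each
boundary (editorial sketch, assuming `np` is numpy):

    _parallel_union(np.array([[0.0, 3.0], [6.0, 9.0]]))
    # -> (array([0., 3., 6., 9.]), array([ 1., -1.,  1., -1.]))
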
 def interval_list_union(
-    interval_list1, interval_list2, min_length=0.0, max_length=1e10
-):
+    interval_list1: np.ndarray,
+    interval_list2: np.ndarray,
+    min_length: float = 0.0,
+    max_length: float = 1e10,
+) -> np.ndarray:
     """Finds the union (all times in one or both) for two interval lists

-    :param interval_list1: The first interval list
-    :type interval_list1: numpy array of intervals [start, stop]
-    :param interval_list2: The second interval list
-    :type interval_list2: numpy array of intervals [start, stop]
-    :param min_length: optional minimum length of interval for inclusion in output, default 0.0
-    :type min_length: float
-    :param max_length: optional maximum length of interval for inclusion in output, default 1e10
-    :type max_length: float
-    :return: interval_list
-    :rtype: numpy array of intervals [start, stop]
+    Parameters
+    ----------
+    interval_list1 : np.ndarray
+        The first interval list [start, stop]
+    interval_list2 : np.ndarray
+        The second interval list [start, stop]
+    min_length : float, optional
+        Minimum length of interval for inclusion in output, default 0.0
+    max_length : float, optional
+        Maximum length of interval for inclusion in output, default 1e10
+
+    Returns
+    -------
+    np.ndarray
+        Array of intervals [start, stop]
     """
-    # return np.array([min(interval_list1[0],interval_list2[0]),
-    #                  max(interval_list1[1],interval_list2[1])])
-    interval_list1 = np.ravel(interval_list1)
-    # create a parallel list where 1 indicates the start and -1 the end of an interval
-    interval_list1_start_end = np.ones(interval_list1.shape)
-    interval_list1_start_end[1::2] = -1
-
-    interval_list2 = np.ravel(interval_list2)
-    # create a parallel list for the second interval where 1 indicates the start and -1 the end of an interval
-    interval_list2_start_end = np.ones(interval_list2.shape)
-    interval_list2_start_end[1::2] = -1
-
-    # concatenate the two lists so we can resort the intervals and apply the same sorting to the start-end arrays
-    combined_intervals = np.concatenate((interval_list1, interval_list2))
-    ss = np.concatenate((interval_list1_start_end, interval_list2_start_end))
+
+    il1, il1_start_end = _parallel_union(interval_list1)
+    il2, il2_start_end = _parallel_union(interval_list2)
+
+    # Concatenate the two lists so we can resort the intervals and apply the
+    # same sorting to the start-end arrays
+    combined_intervals = np.concatenate((il1, il2))
+    ss = np.concatenate((il1_start_end, il2_start_end))
     sort_ind = np.argsort(combined_intervals)
     combined_intervals = combined_intervals[sort_ind]
-    # a cumulative sum of 1 indicates the beginning of a joint interval; a cumulative sum of 0 indicates the end
+
+    # a cumulative sum of 1 indicates the beginning of a joint interval; a
+    # cumulative sum of 0 indicates the end
     union_starts = np.ravel(np.array(np.where(np.cumsum(ss[sort_ind]) == 1)))
     union_stops = np.ravel(np.array(np.where(np.cumsum(ss[sort_ind]) == 0)))

-    union = []
-    for start, stop in zip(union_starts, union_stops):
-        union.append([combined_intervals[start], combined_intervals[stop]])
+    union = [
+        [combined_intervals[start], combined_intervals[stop]]
+        for start, stop in zip(union_starts, union_stops)
+    ]
+
     return np.asarray(union)


 def interval_list_censor(interval_list, timestamps):
-    """returns a new interval list that starts and ends at the first and last timestamp
+    """Return an interval list that starts/ends at the first/last timestamp

     Parameters
     ----------
@@ -442,9 +420,10 @@
         interval_list (numpy array of intervals [start, stop])
     """
     # check that all timestamps are in the interval list
-    assert len(interval_list_contains_ind(interval_list, timestamps)) == len(
+    if len(interval_list_contains_ind(interval_list, timestamps)) != len(
         timestamps
-    ), "interval_list must contain all timestamps"
+    ):
+        raise ValueError("interval_list must contain all timestamps")

     timestamps_interval = np.asarray([[timestamps[0], timestamps[-1]]])
     return interval_list_intersect(interval_list, timestamps_interval)
@@ -452,6 +431,7 @@
 def interval_from_inds(list_frames):
     """Converts a list of indices to a list of intervals.
+
     e.g. [2,3,4,6,7,8,9,10] -> [[2,4],[6,10]]

     Parameters
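
Given the `# TODO: test interval_list_union code` note above, a worked trace
of the start/end bookkeeping may help (editorial sketch, assuming `np` is
numpy):

    il1 = np.array([[0.0, 3.0], [6.0, 9.0]])
    il2 = np.array([[5.0, 7.0]])
    # flattened values: [0, 3, 6, 9, 5, 7]; tags: [1, -1, 1, -1, 1, -1]
    # sorted values:    [0, 3, 5, 6, 7, 9]; tags: [1, -1, 1, 1, -1, -1]
    # cumulative sum:   [1, 0, 1, 2, 1, 0]
    # starts (cumsum == 1): indices 0, 2, 4; stops (cumsum == 0): indices 1, 5
    # zip pairs (0, 1) and (2, 5), discarding the extra start at index 4
    interval_list_union(il1, il2)  # -> array([[0., 3.], [5., 9.]])

Two caveats worth testing: `cumsum == 1` also fires when the count drops from
2 back to 1 (index 4 here), and `zip` only happens to discard it in this
ordering; and `min_length`/`max_length` are accepted but never applied in the
body shown.
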
diff --git a/src/spyglass/common/common_lab.py b/src/spyglass/common/common_lab.py
index 9be9d15b2..712384273 100644
--- a/src/spyglass/common/common_lab.py
+++ b/src/spyglass/common/common_lab.py
@@ -25,8 +25,8 @@ class LabMemberInfo(dj.Part):
         # Information about lab member in the context of Frank lab network
         -> LabMember
         ---
-        google_user_name: varchar(200)  # used for permission to curate
-        datajoint_user_name = "": varchar(200)  # used for permission to delete entries
+        google_user_name: varchar(200)  # For permission to curate
+        datajoint_user_name = "": varchar(200)  # Permission to delete entries
         """

     @classmethod
@@ -108,9 +108,10 @@ def create_new_team(
         team_description: str
             The description of the team.
         """
-        labteam_dict = dict()
-        labteam_dict["team_name"] = team_name
-        labteam_dict["team_description"] = team_description
+        labteam_dict = {
+            "team_name": team_name,
+            "team_description": team_description,
+        }
         cls.insert1(labteam_dict, skip_duplicates=True)

         for team_member in team_members:
@@ -120,12 +121,13 @@ def create_new_team(
             ).fetch("google_user_name")
             if not query:
                 print(
-                    f"Please add the Google user ID for {team_member} in the "
-                    + "LabMember.LabMemberInfo table to help manage permissions."
+                    f"Please add the Google user ID for {team_member} in "
+                    + "LabMember.LabMemberInfo to help manage permissions."
                 )
-            labteammember_dict = dict()
-            labteammember_dict["team_name"] = team_name
-            labteammember_dict["lab_member_name"] = team_member
+            labteammember_dict = {
+                "team_name": team_name,
+                "lab_member_name": team_member,
+            }
             cls.LabTeamMember.insert1(labteammember_dict, skip_duplicates=True)

@@ -133,7 +135,6 @@ def create_new_team(
 class Institution(dj.Manual):
     definition = """
     institution_name: varchar(80)
-    ---
     """

     @classmethod
@@ -148,6 +149,7 @@ def insert_from_nwbfile(cls, nwbf):
         if nwbf.institution is None:
             print("No institution metadata found.\n")
             return
+
         cls.insert1(
             dict(institution_name=nwbf.institution), skip_duplicates=True
         )
@@ -157,7 +159,6 @@ class Lab(dj.Manual):
     definition = """
     lab_name: varchar(80)
-    ---
     """

     @classmethod
@@ -198,5 +199,7 @@ def decompose_name(full_name: str) -> tuple:
         last, first = full_name.title().split(", ")
     else:
         first, last = full_name.title().split(" ")
+    full = f"{first} {last}"
+
     return full, first, last
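
The last hunk appears to fix a latent NameError: `full` was returned but never
assigned. A quick check of the repaired behavior (editorial sketch; the comma
branch for "Last, First" input is inferred from the context lines):

    decompose_name("Ada Lovelace")   # -> ("Ada Lovelace", "Ada", "Lovelace")
    decompose_name("Lovelace, Ada")  # -> ("Ada Lovelace", "Ada", "Lovelace")
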
""" - key = dict() - key["region_name"] = region_name - key["subregion_name"] = subregion_name - key["subsubregion_name"] = subsubregion_name - query = BrainRegion & key - if not query: - cls.insert1(key) - query = BrainRegion & key - return query.fetch1("region_id") + key = dict( + region_name=region_name, + subregion_name=subregion_name, + subsubregion_name=subsubregion_name, + ) + cls.insert1(key, skip_duplicates=True) + return (BrainRegion & key).fetch1("region_id") diff --git a/src/spyglass/spikesorting/spikesorting_recording.py b/src/spyglass/spikesorting/spikesorting_recording.py index 593ff8db1..cdad793e9 100644 --- a/src/spyglass/spikesorting/spikesorting_recording.py +++ b/src/spyglass/spikesorting/spikesorting_recording.py @@ -21,6 +21,7 @@ from ..common.common_nwbfile import Nwbfile from ..common.common_session import Session # noqa: F401 from ..utils.dj_helper_fn import dj_replace +from ..settings import recording_dir schema = dj.schema("spikesorting_recording") @@ -379,59 +380,67 @@ def make(self, key): recording = self._get_filtered_recording(key) recording_name = self._get_recording_name(key) - tmp_key = { - "nwb_file_name": key["nwb_file_name"], - "interval_list_name": recording_name, - "valid_times": sort_interval_valid_times, - } - IntervalList.insert1(tmp_key, replace=True) + # Path to files that will hold the recording extractors + recording_path = str(recording_dir / Path(recording_name)) + if os.path.exists(recording_path): + shutil.rmtree(recording_path) - # store the list of valid times for the sort - key["sort_interval_list_name"] = tmp_key["interval_list_name"] + recording.save( + folder=recording_path, chunk_duration="10000ms", n_jobs=8 + ) - # Path to files that will hold the recording extractors - recording_folder = Path(os.getenv("SPYGLASS_RECORDING_DIR")) - key["recording_path"] = str(recording_folder / Path(recording_name)) - if os.path.exists(key["recording_path"]): - shutil.rmtree(key["recording_path"]) - recording = recording.save( - folder=key["recording_path"], chunk_duration="10000ms", n_jobs=8 + IntervalList.insert1( + { + "nwb_file_name": key["nwb_file_name"], + "interval_list_name": recording_name, + "valid_times": sort_interval_valid_times, + }, + replace=True, ) - self.insert1(key) + self.insert1( + { + **key, + # store the list of valid times for the sort + "sort_interval_list_name": recording_name, + "recording_path": recording_path, + } + ) @staticmethod def _get_recording_name(key): - recording_name = ( - key["nwb_file_name"] - + "_" - + key["sort_interval_name"] - + "_" - + str(key["sort_group_id"]) - + "_" - + key["preproc_params_name"] + return "_".join( + [ + key["nwb_file_name"], + key["sort_interval_name"], + str(key["sort_group_id"]), + key["preproc_params_name"], + ] ) - return recording_name @staticmethod def _get_recording_timestamps(recording): - if recording.get_num_segments() > 1: - frames_per_segment = [0] - for i in range(recording.get_num_segments()): - frames_per_segment.append( - recording.get_num_frames(segment_index=i) - ) + num_segments = recording.get_num_segments() - cumsum_frames = np.cumsum(frames_per_segment) - total_frames = np.sum(frames_per_segment) + if num_segments <= 1: + return recording.get_times() + + frames_per_segment = [0] + [ + recording.get_num_frames(segment_index=i) + for i in range(num_segments) + ] + + cumsum_frames = np.cumsum(frames_per_segment) + total_frames = np.sum(frames_per_segment) + + timestamps = np.zeros((total_frames,)) + for i in range(num_segments): + start_index = cumsum_frames[i] + 
     @staticmethod
     def _get_recording_timestamps(recording):
-        if recording.get_num_segments() > 1:
-            frames_per_segment = [0]
-            for i in range(recording.get_num_segments()):
-                frames_per_segment.append(
-                    recording.get_num_frames(segment_index=i)
-                )
+        num_segments = recording.get_num_segments()

-            cumsum_frames = np.cumsum(frames_per_segment)
-            total_frames = np.sum(frames_per_segment)
+        if num_segments <= 1:
+            return recording.get_times()
+
+        frames_per_segment = [0] + [
+            recording.get_num_frames(segment_index=i)
+            for i in range(num_segments)
+        ]
+
+        cumsum_frames = np.cumsum(frames_per_segment)
+        total_frames = np.sum(frames_per_segment)
+
+        timestamps = np.zeros((total_frames,))
+        for i in range(num_segments):
+            start_index = cumsum_frames[i]
+            end_index = cumsum_frames[i + 1]
+            timestamps[start_index:end_index] = recording.get_times(
+                segment_index=i
+            )

-            timestamps = np.zeros((total_frames,))
-            for i in range(recording.get_num_segments()):
-                timestamps[
-                    cumsum_frames[i] : cumsum_frames[i + 1]
-                ] = recording.get_times(segment_index=i)
-        else:
-            timestamps = recording.get_times()
         return timestamps

     def _get_sort_interval_valid_times(self, key):
@@ -456,9 +465,11 @@
                 "sort_interval_name": key["sort_interval_name"],
             }
         ).fetch1("sort_interval")
+
         interval_list_name = (SpikeSortingRecordingSelection & key).fetch1(
             "interval_list_name"
         )
+
         valid_interval_times = (
             IntervalList
             & {
@@ -466,6 +477,7 @@
                 "interval_list_name": interval_list_name,
             }
         ).fetch1("valid_times")
+
         valid_sort_times = interval_list_intersect(
             sort_interval, valid_interval_times
         )
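
The refactored `_get_recording_timestamps` concatenates per-segment clocks
into one vector. Its indexing can be sanity-checked with plain arrays standing
in for the recording object (editorial sketch, not part of the patch):

    import numpy as np

    segment_times = [np.array([0.0, 0.1, 0.2]), np.array([5.0, 5.1])]
    frames_per_segment = [0] + [len(t) for t in segment_times]
    cumsum_frames = np.cumsum(frames_per_segment)  # -> [0, 3, 5]

    timestamps = np.zeros((cumsum_frames[-1],))
    for i, times in enumerate(segment_times):
        # each segment fills the slice [cumsum_frames[i], cumsum_frames[i + 1])
        timestamps[cumsum_frames[i] : cumsum_frames[i + 1]] = times
    # timestamps -> [0.0, 0.1, 0.2, 5.0, 5.1]
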