diff --git a/.github/workflows/test-conda.yml b/.github/workflows/test-conda.yml index cd793a480..594a7b2b8 100644 --- a/.github/workflows/test-conda.yml +++ b/.github/workflows/test-conda.yml @@ -17,16 +17,6 @@ jobs: env: OS: ${{ matrix.os }} PYTHON: '3.8' - # SPYGLASS_BASE_DIR: ./data - # KACHERY_STORAGE_DIR: ./data/kachery-storage - # DJ_SUPPORT_FILEPATH_MANAGEMENT: True - # services: - # datajoint_test_server: - # image: datajoint/mysql - # ports: - # - 3306:3306 - # options: >- - # -e MYSQL_ROOT_PASSWORD=tutorial steps: - name: Cancel Workflow Action uses: styfle/cancel-workflow-action@0.11.0 @@ -49,6 +39,17 @@ jobs: - name: Install spyglass run: | pip install -e .[test] + - name: Download data + env: + UCSF_BOX_TOKEN: ${{ secrets.UCSF_BOX_TOKEN }} + UCSF_BOX_USER: ${{ secrets.UCSF_BOX_USER }} + WEBSITE: ftps://ftp.box.com/trodes_to_nwb_test_data/minirec20230622.nwb + RAW_DIR: /home/runner/work/spyglass/spyglass/tests/_data/raw/ + run: | + mkdir -p $RAW_DIR + wget --recursive --no-verbose --no-host-directories --no-directories \ + --user $UCSF_BOX_USER --password $UCSF_BOX_TOKEN \ + -P $RAW_DIR $WEBSITE - name: Run tests run: | pytest -rP # env vars are set within certain tests diff --git a/CHANGELOG.md b/CHANGELOG.md index 895702b43..302c116d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ - Add `deprecation_factory` to facilitate table migration. #717 - Add Spyglass logger. #730 - IntervalList: Add secondary key `pipeline` #742 +- Increase pytest coverage for `common`, `lfp`, and `utils`. #743 ### Pipelines @@ -31,7 +32,6 @@ - Allow multiple spike waveform features for clusterelss decoding #731 - Reorder notebooks #731 - ## [0.4.3] (November 7, 2023) - Migrate `config` helper scripts to Spyglass codebase. #662 diff --git a/pyproject.toml b/pyproject.toml index 33a7df931..521224737 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,7 @@ spyglass_cli = "spyglass.cli:cli" [project.optional-dependencies] position = ["ffmpeg", "numba>=0.54", "deeplabcut<2.3.0"] test = [ + "docker", # for tests in a container "pytest", # unit testing "pytest-cov", # code coverage "kachery", # database access @@ -109,5 +110,41 @@ line-length = 80 [tool.codespell] skip = '.git,*.pdf,*.svg,*.ipynb,./docs/site/**,temp*' -# Nevers - name in Citation ignore-words-list = 'nevers' +# Nevers - name in Citation + +[tool.pytest.ini_options] +minversion = "7.0" +addopts = [ + "-sv", + "-p no:warnings", + # "--no-teardown", # don't teardown the database after tests + # "--quiet-spy", # don't show logging from spyglass + "--show-capture=no", + "--pdbcls=IPython.terminal.debugger:TerminalPdb", # use ipython debugger + "--cov=spyglass", + "--cov-report=term-missing", + "--no-cov-on-fail", +] +testpaths = ["tests"] +log_level = "INFO" + +[tool.coverage.run] +source = ["*/src/spyglass/*"] +omit = [ # which submodules have no tests + "*/__init__.py", + "*/_version.py", + "*/cli/*", + # "*/common/*", + "*/data_import/*", + "*/decoding/*", + "*/figurl_views/*", + # "*/lfp/*", + "*/linearization/*", + "*/lock/*", + "*/position/*", + "*/ripple/*", + "*/sharing/*", + "*/spikesorting/*", + # "*/utils/*", +] diff --git a/src/spyglass/common/common_device.py b/src/spyglass/common/common_device.py index 223862c81..2dd03c822 100644 --- a/src/spyglass/common/common_device.py +++ b/src/spyglass/common/common_device.py @@ -2,8 +2,8 @@ import ndx_franklab_novela from spyglass.common.errors import PopulateException -from spyglass.utils.dj_mixin import SpyglassMixin -from spyglass.utils.logging import logger 
+from spyglass.settings import test_mode +from spyglass.utils import SpyglassMixin, logger from spyglass.utils.nwb_helper_fn import get_nwb_file schema = dj.schema("common_device") @@ -154,25 +154,9 @@ def _add_device(cls, new_device_dict): all_values = DataAcquisitionDevice.fetch( "data_acquisition_device_name" ).tolist() - if name not in all_values: - # no entry with the same name exists, prompt user to add a new entry - logger.info( - f"\nData acquisition device '{name}' was not found in the " - f"database. The current values are: {all_values}. " - "Please ensure that the device you want to add does not already" - " exist in the database under a different name or spelling. " - "If you want to use an existing device in the database, " - "please change the corresponding Device object in the NWB file." - " Entering 'N' will raise an exception." - ) - to_db = " to the database" - val = input(f"Add data acquisition device '{name}'{to_db}? (y/N)") - if val.lower() in ["y", "yes"]: - cls.insert1(new_device_dict, skip_duplicates=True) - return - raise PopulateException( - f"User chose not to add device '{name}'{to_db}." - ) + if prompt_insert(name=name, all_values=all_values): + cls.insert1(new_device_dict, skip_duplicates=True) + return # Check if values provided match the values stored in the database db_dict = ( @@ -213,28 +197,11 @@ def _add_system(cls, system): all_values = DataAcquisitionDeviceSystem.fetch( "data_acquisition_device_system" ).tolist() - if system not in all_values: - logger.info( - f"\nData acquisition device system '{system}' was not found in" - f" the database. The current values are: {all_values}. " - "Please ensure that the system you want to add does not already" - " exist in the database under a different name or spelling. " - "If you want to use an existing system in the database, " - "please change the corresponding Device object in the NWB file." - " Entering 'N' will raise an exception." - ) - val = input( - f"Do you want to add data acquisition device system '{system}'" - + " to the database? (y/N)" - ) - if val.lower() in ["y", "yes"]: - key = {"data_acquisition_device_system": system} - DataAcquisitionDeviceSystem.insert1(key, skip_duplicates=True) - else: - raise PopulateException( - "User chose not to add data acquisition device system " - + f"'{system}' to the database." - ) + if prompt_insert( + name=system, all_values=all_values, table_type="system" + ): + key = {"data_acquisition_device_system": system} + DataAcquisitionDeviceSystem.insert1(key, skip_duplicates=True) return system @classmethod @@ -264,30 +231,11 @@ def _add_amplifier(cls, amplifier): all_values = DataAcquisitionDeviceAmplifier.fetch( "data_acquisition_device_amplifier" ).tolist() - if amplifier not in all_values: - logger.info( - f"\nData acquisition device amplifier '{amplifier}' was not " - f"found in the database. The current values are: {all_values}. " - "Please ensure that the amplifier you want to add does not " - "already exist in the database under a different name or " - "spelling. If you want to use an existing name in the database," - " please change the corresponding Device object in the NWB " - "file. Entering 'N' will raise an exception." - ) - val = input( - "Do you want to add data acquisition device amplifier " - + f"'{amplifier}' to the database? 
(y/N)" - ) - if val.lower() in ["y", "yes"]: - key = {"data_acquisition_device_amplifier": amplifier} - DataAcquisitionDeviceAmplifier.insert1( - key, skip_duplicates=True - ) - else: - raise PopulateException( - "User chose not to add data acquisition device amplifier " - + f"'{amplifier}' to the database." - ) + if prompt_insert( + name=amplifier, all_values=all_values, table_type="amplifier" + ): + key = {"data_acquisition_device_amplifier": amplifier} + DataAcquisitionDeviceAmplifier.insert1(key, skip_duplicates=True) return amplifier @@ -576,27 +524,9 @@ def _add_probe_type(cls, new_probe_type_dict): """ probe_type = new_probe_type_dict["probe_type"] all_values = ProbeType.fetch("probe_type").tolist() - if probe_type not in all_values: - logger.info( - f"\nProbe type '{probe_type}' was not found in the database. " - f"The current values are: {all_values}. " - "Please ensure that the probe type you want to add does not " - "already exist in the database under a different name or " - "spelling. If you want to use an existing name in the " - "database, please change the corresponding Probe object in the " - "NWB file. Entering 'N' will raise an exception." - ) - val = input( - f"Do you want to add probe type '{probe_type}' to the database?" - + " (y/N)" - ) - if val.lower() in ["y", "yes"]: - ProbeType.insert1(new_probe_type_dict, skip_duplicates=True) - return - raise PopulateException( - f"User chose not to add probe type '{probe_type}' to the " - + "database." - ) + if prompt_insert(probe_type, all_values, table="probe type"): + ProbeType.insert1(new_probe_type_dict, skip_duplicates=True) + return # else / entry exists: check whether the values provided match the # values stored in the database @@ -738,3 +668,55 @@ def create_from_nwbfile( cls.Shank.insert1(shank, skip_duplicates=True) for electrode in elect_dict.values(): cls.Electrode.insert1(electrode, skip_duplicates=True) + + +# ---------------------------- Helper functions ---------------------------- + + +# Migrated down to reduce redundancy and centralize 'test_mode' check for pytest +def prompt_insert( + name: str, + all_values: list, + table: str = "Data Acquisition Device", + table_type: str = None, +) -> bool: + """Prompt user to add an item to the database. Return True if yes. + + Assume insert during test mode. + + Parameters + ---------- + name : str + The name of the item to add. + all_values : list + List of all values in the database. + table : str, optional + The name of the table to add to, by default Data Acquisition Device + table_type : str, optional + The type of item to add, by default None. Data Acquisition Device X + """ + if name in all_values: + return False + + if test_mode: + return True + + if table_type: + table_type += " " + + logger.info( + f"{table}{table_type} '{name}' was not found in the" + f"database. The current values are: {all_values}.\n" + "Please ensure that the device you want to add does not already" + "exist in the database under a different name or spelling. If you" + "want to use an existing device in the database, please change the" + "corresponding Device object in the NWB file.\nEntering 'N' will " + "raise an exception." + ) + msg = f"Do you want to add {table}{table_type} '{name}' to the database?" + if dj.utils.user_choice(msg).lower() in ["y", "yes"]: + return True + + raise PopulateException( + f"User chose not to add {table}{table_type} '{name}' to the database." 
+ ) diff --git a/src/spyglass/common/common_dio.py b/src/spyglass/common/common_dio.py index 93a087116..7eae1e9d3 100644 --- a/src/spyglass/common/common_dio.py +++ b/src/spyglass/common/common_dio.py @@ -50,7 +50,7 @@ def make(self, key): key["dio_object_id"] = event_series.object_id self.insert1(key, skip_duplicates=True) - def plot_all_dio_events(self): + def plot_all_dio_events(self, return_fig=False): """Plot all DIO events in the session. Examples @@ -117,3 +117,6 @@ def plot_all_dio_events(self): plt.suptitle(f"DIO events in {nwb_file_names[0]}") else: plt.suptitle(f"DIO events in {', '.join(nwb_file_names)}") + + if return_fig: + return plt.gcf() diff --git a/src/spyglass/common/common_filter.py b/src/spyglass/common/common_filter.py index 0472c6e18..9d2cdf9d6 100644 --- a/src/spyglass/common/common_filter.py +++ b/src/spyglass/common/common_filter.py @@ -167,9 +167,9 @@ def add_filter( def _filter_restrict(self, filter_name, fs): return ( self & {"filter_name": filter_name} & {"filter_sampling_rate": fs} - ).fetch1(as_dict=True) + ).fetch1() - def plot_magnitude(self, filter_name, fs): + def plot_magnitude(self, filter_name, fs, return_fig=False): filter_dict = self._filter_restrict(filter_name, fs) plt.figure() w, h = signal.freqz(filter_dict["filter_coeff"], worN=65536) @@ -178,11 +178,13 @@ def plot_magnitude(self, filter_name, fs): plt.xlabel("Frequency (Hz)") plt.ylabel("Magnitude") plt.title("Frequency Response") - plt.xlim(0, np.max(filter_dict["filter_coeffand_edges"] * 2)) + plt.xlim(0, np.max(filter_dict["filter_band_edges"] * 2)) plt.ylim(np.min(magnitude), -1 * np.min(magnitude) * 0.1) plt.grid(True) + if return_fig: + return plt.gcf() - def plot_fir_filter(self, filter_name, fs): + def plot_fir_filter(self, filter_name, fs, return_fig=False): filter_dict = self._filter_restrict(filter_name, fs) plt.figure() plt.clf() @@ -191,6 +193,8 @@ def plot_fir_filter(self, filter_name, fs): plt.ylabel("Magnitude") plt.title("Filter Taps") plt.grid(True) + if return_fig: + return plt.gcf() def filter_delay(self, filter_name, fs): return self.calc_filter_delay( diff --git a/src/spyglass/common/common_interval.py b/src/spyglass/common/common_interval.py index b03055f88..d754261fc 100644 --- a/src/spyglass/common/common_interval.py +++ b/src/spyglass/common/common_interval.py @@ -66,7 +66,7 @@ def insert_from_nwbfile(cls, nwbf, *, nwb_file_name): cls.insert1(epoch_dict, skip_duplicates=True) - def plot_intervals(self, figsize=(20, 5)): + def plot_intervals(self, figsize=(20, 5), return_fig=False): interval_list = pd.DataFrame(self) fig, ax = plt.subplots(figsize=figsize) interval_count = 0 @@ -84,8 +84,10 @@ def plot_intervals(self, figsize=(20, 5)): ax.set_yticklabels(interval_list.interval_list_name) ax.set_xlabel("Time [s]") ax.grid(True) + if return_fig: + return fig - def plot_epoch_pos_raw_intervals(self, figsize=(20, 5)): + def plot_epoch_pos_raw_intervals(self, figsize=(20, 5), return_fig=False): interval_list = pd.DataFrame(self) fig, ax = plt.subplots(figsize=(30, 3)) @@ -145,6 +147,8 @@ def plot_epoch_pos_raw_intervals(self, figsize=(20, 5)): ax.set_yticklabels(["pos valid times", "raw data valid times", "epoch"]) ax.set_xlabel("Time [s]") ax.grid(True) + if return_fig: + return fig def intervals_by_length(interval_list, min_length=0.0, max_length=1e10): diff --git a/src/spyglass/common/common_position.py b/src/spyglass/common/common_position.py index ea661a29d..732c9779e 100644 --- a/src/spyglass/common/common_position.py +++ b/src/spyglass/common/common_position.py @@ 
-8,7 +8,6 @@ import pynwb.behavior
 from position_tools import (
     get_angle,
-    get_centriod,
     get_distance,
     get_speed,
     get_velocity,
@@ -30,6 +29,12 @@
 from spyglass.utils import SpyglassMixin, logger
 from spyglass.utils.dj_helper_fn import deprecated_factory

+try:
+    from position_tools import get_centroid
+except ImportError:
+    logger.warning("Please update position_tools to >= 0.1.0")
+    from position_tools import get_centriod as get_centroid
+
 schema = dj.schema("common_position")

@@ -417,7 +422,7 @@ def calculate_position_info(
         )

         # Calculate position, orientation, velocity, speed
-        position = get_centriod(back_LED, front_LED)  # cm
+        position = get_centroid(back_LED, front_LED)  # cm

         orientation = get_angle(back_LED, front_LED)  # radians

         is_nan = np.isnan(orientation)
diff --git a/src/spyglass/common/common_session.py b/src/spyglass/common/common_session.py
index 6792453bc..f6f783262 100644
--- a/src/spyglass/common/common_session.py
+++ b/src/spyglass/common/common_session.py
@@ -63,13 +63,15 @@ def make(self, key):
         nwbf = get_nwb_file(nwb_file_abspath)
         config = get_config(nwb_file_abspath)

-        # certain data are not associated with a single NWB file / session because they may apply to
-        # multiple sessions. these data go into dj.Manual tables.
-        # e.g., a lab member may be associated with multiple experiments, so the lab member table should not
-        # be dependent on (contain a primary key for) a session.
-
-        # here, we create new entries in these dj.Manual tables based on the values read from the NWB file
-        # then, they are linked to the session via fields of Session (e.g., Subject, Institution, Lab) or part
+        # certain data are not associated with a single NWB file / session
+        # because they may apply to multiple sessions. these data go into
+        # dj.Manual tables. e.g., a lab member may be associated with multiple
+        # experiments, so the lab member table should not be dependent on
+        # (contain a primary key for) a session.
+
+        # here, we create new entries in these dj.Manual tables based on the
+        # values read from the NWB file. Then, they are linked to the session
+        # via fields of Session (e.g., Subject, Institution, Lab) or part
         # tables (e.g., Experimenter, DataAcquisitionDevice).
logger.info("Institution...") @@ -221,17 +223,19 @@ def add_session_to_group( ) @staticmethod - def remove_session_from_group(nwb_file_name: str, session_group_name: str): + def remove_session_from_group( + nwb_file_name: str, session_group_name: str, *args, **kwargs + ): query = { "session_group_name": session_group_name, "nwb_file_name": nwb_file_name, } - (SessionGroupSession & query).delete() + (SessionGroupSession & query).delete(*args, **kwargs) @staticmethod - def delete_group(session_group_name: str): + def delete_group(session_group_name: str, *args, **kwargs): query = {"session_group_name": session_group_name} - (SessionGroup & query).delete() + (SessionGroup & query).delete(*args, **kwargs) @staticmethod def get_group_sessions(session_group_name: str): diff --git a/src/spyglass/data_import/__init__.py b/src/spyglass/data_import/__init__.py index 703cfa3c1..9c68cf038 100644 --- a/src/spyglass/data_import/__init__.py +++ b/src/spyglass/data_import/__init__.py @@ -1 +1,2 @@ +# TODO: change naming to avoid match between module and function from .insert_sessions import insert_sessions diff --git a/src/spyglass/data_import/insert_sessions.py b/src/spyglass/data_import/insert_sessions.py index c862fe85b..329a7be42 100644 --- a/src/spyglass/data_import/insert_sessions.py +++ b/src/spyglass/data_import/insert_sessions.py @@ -101,7 +101,7 @@ def copy_nwb_link_raw_ephys(nwb_file_name, out_nwb_file_name): if os.path.exists(out_nwb_file_abs_path): if debug_mode: return out_nwb_file_abs_path - warnings.warn( + logger.warning( f"Output file {out_nwb_file_abs_path} exists and will be " + "overwritten." ) diff --git a/src/spyglass/decoding/decoding_merge.py b/src/spyglass/decoding/decoding_merge.py index c49971c78..1752b1165 100644 --- a/src/spyglass/decoding/decoding_merge.py +++ b/src/spyglass/decoding/decoding_merge.py @@ -21,14 +21,14 @@ class DecodingOutput(_Merge, SpyglassMixin): source: varchar(32) """ - class ClusterlessDecodingV1(SpyglassMixin, dj.Part): + class ClusterlessDecodingV1(SpyglassMixin, dj.Part): # noqa: F811 definition = """ -> master --- -> ClusterlessDecodingV1 """ - class SortedSpikesDecodingV1(SpyglassMixin, dj.Part): + class SortedSpikesDecodingV1(SpyglassMixin, dj.Part): # noqa: F811 definition = """ -> master --- diff --git a/src/spyglass/settings.py b/src/spyglass/settings.py index 4672af615..e2e0a2142 100644 --- a/src/spyglass/settings.py +++ b/src/spyglass/settings.py @@ -30,7 +30,8 @@ def __init__(self, base_dir: str = None, **kwargs): self.supplied_base_dir = base_dir self._config = dict() self.config_defaults = dict(prepopulate=True) - self._debug_mode = False + self._debug_mode = kwargs.get("debug_mode", False) + self._test_mode = kwargs.get("test_mode", False) self._dlc_base = None self.relative_dirs = { @@ -106,6 +107,7 @@ def load_config(self, force_reload=False): dj_dlc = dj_custom.get("dlc_dirs", {}) self._debug_mode = dj_custom.get("debug_mode", False) + self._test_mode = dj_custom.get("test_mode", False) resolved_base = ( self.supplied_base_dir @@ -166,6 +168,7 @@ def load_config(self, force_reload=False): self._config = dict( debug_mode=self._debug_mode, + test_mode=self._test_mode, **self.config_defaults, **config_dirs, **kachery_zone_dict, @@ -381,6 +384,7 @@ def _dj_custom(self) -> dict: return { "custom": { "debug_mode": str(self.debug_mode).lower(), + "test_mode": str(self._test_mode).lower(), "spyglass_dirs": { "base": self.base_dir, "raw": self.raw_dir, @@ -453,8 +457,19 @@ def video_dir(self) -> str: @property def debug_mode(self) -> bool: + 
"""Returns True if debug_mode is set. + + Supports skipping inserts for Dockerized development. + """ return self._debug_mode + @property + def test_mode(self) -> bool: + """Returns True if test_mode is set. + + Required for pytests to run without prompts.""" + return self._test_mode + @property def dlc_project_dir(self) -> str: return self.config.get(self.dir_to_var("project", "dlc")) @@ -479,6 +494,7 @@ def dlc_output_dir(self) -> str: waveform_dir = sg_config.waveform_dir video_dir = sg_config.video_dir debug_mode = sg_config.debug_mode +test_mode = sg_config.test_mode prepopulate = config.get("prepopulate", False) dlc_project_dir = sg_config.dlc_project_dir dlc_video_dir = sg_config.dlc_video_dir diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index eddd77652..5c900b66c 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -758,7 +758,7 @@ def delete_downstream_merge( def _warn_on_restriction(table: dj.Table, restriction: str = None): """Warn if restriction on table object differs from input restriction""" - if restriction is None and table().restriction: + if restriction is None and table.restriction: logger.warn( f"Warning: ignoring table restriction: {table().restriction}.\n\t" + "Please pass restrictions as an arg" diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 8a53743de..3ee0f6292 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -226,7 +226,13 @@ def _check_delete_permission(self) -> None: user_name = LabMember().get_djuser_name(dj_user) for experimenter in set(experimenters): - if user_name not in LabTeam().get_team_members(experimenter): + # Check once with cache, if fails, reload and check again + # On eval as set, reload will only be called once + if user_name not in LabTeam().get_team_members( + experimenter + ) and user_name not in LabTeam().get_team_members( + experimenter, reload=True + ): sess_w_exp = sess_summary & {self._member_pk: experimenter} raise PermissionError( f"User '{user_name}' is not on a team with '{experimenter}'" @@ -259,7 +265,7 @@ def cautious_delete(self, force_permission: bool = False, *args, **kwargs): merge_deletes = self._merge_del_func( self, - restriction=self.restriction, + restriction=self.restriction if self.restriction else None, dry_run=True, disable_warning=True, ) diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 000000000..476dbb4c8 --- /dev/null +++ b/tests/README.md @@ -0,0 +1,47 @@ +# PyTests + +This directory is contains files for testing the code. Simply by running +`pytest` from the root directory, all tests will be run with default parameters +specified in `pyproject.toml`. Notable optional parameters include... + +- Coverage items. The coverage report indicates what percentage of the code was + included in tests. + + - `--cov=spyglatss`: Which package should be described in the coverage report + - `--cov-report term-missing`: Include lines of items missing in coverage + +- Verbosity. + + - `-v`: List individual tests, report pass/fail + - `--quiet-spy`: Default False. When True, print and other logging statements + from Spyglass are silenced. + +- Data and database. + + - `--no-server`: Default False, launch Docker container from python. When + True, no server is started and tests attempt to connect to existing + container. + - `--no-teardown`: Default False. When True, docker database tables are + preserved on exit. 
Set to True to inspect output items after testing.
+  - `--my-datadir ./rel-path/`: Default `./tests/test_data/`. Where to store
+    created files.
+
+- Incremental running.
+
+  - `-m`: Run tests with the
+    [given marker](https://docs.pytest.org/en/6.2.x/usage.html#specifying-tests-selecting-tests)
+    (e.g., `pytest -m current`).
+  - `--sw`: Stepwise. Continue from previously failed test when starting again.
+  - `-s`: No capture. By including `from IPython import embed; embed()` in a
+    test, and using this flag, you can open an IPython environment from within
+    a test.
+  - `--pdb`: Enter debug mode if a test fails.
+  - `tests/test_file.py -k test_name`: To run just a set of tests, specify the
+    file name at the end of the command. To run a single test, further specify
+    `-k` with the test name.
+
+When customizing parameters, comment out the `addopts` line in `pyproject.toml`.
+
+```console
+pytest -m current --quiet-spy --no-teardown tests/test_file.py -k test_name
+```
diff --git a/tests/ci_config.py b/tests/ci_config.py
deleted file mode 100644
index e329df7ed..000000000
--- a/tests/ci_config.py
+++ /dev/null
@@ -1,27 +0,0 @@
-import os
-from pathlib import Path
-
-import datajoint as dj
-
-# NOTE this env var is set in the GitHub Action directly
-data_dir = Path(os.environ["SPYGLASS_BASE_DIR"])
-
-raw_dir = data_dir / "raw"
-analysis_dir = data_dir / "analysis"
-
-dj.config["database.host"] = "localhost"
-dj.config["database.user"] = "root"
-dj.config["database.password"] = "tutorial"
-dj.config["stores"] = {
-    "raw": {
-        "protocol": "file",
-        "location": str(raw_dir),
-        "stage": str(raw_dir),
-    },
-    "analysis": {
-        "protocol": "file",
-        "location": str(analysis_dir),
-        "stage": str(analysis_dir),
-    },
-}
-dj.config.save_global()
diff --git a/tests/datajoint/__init__.py b/tests/common/__init__.py
similarity index 100%
rename from tests/datajoint/__init__.py
rename to tests/common/__init__.py
diff --git a/tests/common/conftest.py b/tests/common/conftest.py
new file mode 100644
index 000000000..41fdea95a
--- /dev/null
+++ b/tests/common/conftest.py
@@ -0,0 +1,48 @@
+import pytest
+
+
+@pytest.fixture(scope="session")
+def mini_devices(mini_content):
+    yield mini_content.devices
+
+
+@pytest.fixture(scope="session")
+def mini_behavior(mini_content):
+    yield mini_content.processing.get("behavior")
+
+
+@pytest.fixture(scope="session")
+def mini_pos(mini_behavior):
+    yield mini_behavior.get_data_interface("position").spatial_series
+
+
+@pytest.fixture(scope="session")
+def mini_pos_series(mini_pos):
+    yield next(iter(mini_pos))
+
+
+@pytest.fixture(scope="session")
+def mini_pos_interval_dict(common):
+    yield {"interval_list_name": common.PositionSource.get_pos_interval_name(0)}
+
+
+@pytest.fixture(scope="session")
+def mini_pos_tbl(common, mini_pos_series):
+    yield common.PositionSource.SpatialSeries * common.RawPosition.PosObject & {
+        "name": mini_pos_series
+    }
+
+
+@pytest.fixture(scope="session")
+def pos_src(common):
+    yield common.PositionSource()
+
+
+@pytest.fixture(scope="session")
+def pos_interval_01(pos_src):
+    yield [pos_src.get_pos_interval_name(x) for x in range(1)]
+
+
+@pytest.fixture(scope="session")
+def common_ephys(common):
+    yield common.common_ephys
diff --git a/tests/common/test_behav.py b/tests/common/test_behav.py
new file mode 100644
index 000000000..c21ed96f6
--- /dev/null
+++ b/tests/common/test_behav.py
@@ -0,0 +1,73 @@
+import pytest
+from pandas import DataFrame
+
+
+def test_invalid_interval(pos_src):
+    """Test invalid interval"""
+    with
pytest.raises(ValueError): + pos_src.get_pos_interval_name("invalid_interval") + + +def test_invalid_epoch_num(common): + """Test invalid epoch num""" + with pytest.raises(ValueError): + common.PositionSource.get_epoch_num("invalid_epoch_num") + + +def test_raw_position_fetchnwb(common, mini_pos, mini_pos_interval_dict): + """Test RawPosition fetch nwb""" + fetched = DataFrame( + (common.RawPosition & mini_pos_interval_dict) + .fetch_nwb()[0]["raw_position"] + .data + ) + raw = DataFrame(mini_pos["led_0_series_0"].data) + # compare with mini_pos + assert fetched.equals(raw), "RawPosition fetch_nwb failed" + + +@pytest.mark.skip(reason="No video files in mini") +def test_videofile_no_transaction(common, mini_restr): + """Test no transaction""" + common.VideoFile()._no_transaction_make(mini_restr) + + +@pytest.mark.skip(reason="No video files in mini") +def test_videofile_update_entries(common): + """Test update entries""" + common.VideoFile().update_entries() + + +@pytest.mark.skip(reason="No video files in mini") +def test_videofile_getabspath(common, mini_restr): + """Test get absolute path""" + common.VideoFile().getabspath(mini_restr) + + +def test_posinterval_no_transaction(verbose_context, common, mini_restr): + """Test no transaction""" + before = common.PositionIntervalMap().fetch() + with verbose_context: + common.PositionIntervalMap()._no_transaction_make(mini_restr) + after = common.PositionIntervalMap().fetch() + assert ( + len(after) == len(before) + 2 + ), "PositionIntervalMap no_transaction had unexpected effect" + + +def test_get_pos_interval_name(pos_src, pos_interval_01): + """Test get pos interval name""" + names = [f"pos {x} valid times" for x in range(1)] + assert pos_interval_01 == names, "get_pos_interval_name failed" + + +def test_convert_epoch(common, mini_dict, pos_interval_01): + this_key = ( + common.IntervalList & mini_dict & {"interval_list_name": "01_s1"} + ).fetch1() + ret = common.common_behav.convert_epoch_interval_name_to_position_interval_name( + this_key + ) + assert ( + ret == pos_interval_01[0] + ), "convert_epoch_interval_name_to_position_interval_name failed" diff --git a/tests/common/test_common_interval.py b/tests/common/test_common_interval.py deleted file mode 100644 index 293abda91..000000000 --- a/tests/common/test_common_interval.py +++ /dev/null @@ -1,62 +0,0 @@ -import numpy as np -from spyglass.common.common_interval import ( - interval_list_intersect, - interval_set_difference_inds, -) - - -def test_interval_list_intersect1(): - interval_list1 = np.array([[0, 10], [3, 5], [14, 16]]) - interval_list2 = np.array([[10, 11], [9, 14], [13, 18]]) - intersection_list = interval_list_intersect(interval_list1, interval_list2) - assert np.all(intersection_list == np.array([[9, 10], [14, 16]])) - - -def test_interval_list_intersect2(): - # if there is no intersection, return empty list - interval_list1 = np.array([[0, 10], [3, 5]]) - interval_list2 = np.array([[11, 14]]) - intersection_list = interval_list_intersect(interval_list1, interval_list2) - assert len(intersection_list) == 0 - - -def test_interval_set_difference_inds_no_overlap(): - intervals1 = [(0, 5), (8, 10)] - intervals2 = [(5, 8)] - result = interval_set_difference_inds(intervals1, intervals2) - assert result == [(0, 5), (8, 10)] - - -def test_interval_set_difference_inds_overlap(): - intervals1 = [(0, 5), (8, 10)] - intervals2 = [(1, 2), (3, 4), (6, 9)] - result = interval_set_difference_inds(intervals1, intervals2) - assert result == [(0, 1), (2, 3), (4, 5), (9, 10)] - - -def 
test_interval_set_difference_inds_empty_intervals1():
-    intervals1 = []
-    intervals2 = [(1, 2), (3, 4), (6, 9)]
-    result = interval_set_difference_inds(intervals1, intervals2)
-    assert result == []
-
-
-def test_interval_set_difference_inds_empty_intervals2():
-    intervals1 = [(0, 5), (8, 10)]
-    intervals2 = []
-    result = interval_set_difference_inds(intervals1, intervals2)
-    assert result == [(0, 5), (8, 10)]
-
-
-def test_interval_set_difference_inds_equal_intervals():
-    intervals1 = [(0, 5), (8, 10)]
-    intervals2 = [(0, 5), (8, 10)]
-    result = interval_set_difference_inds(intervals1, intervals2)
-    assert result == []
-
-
-def test_interval_set_difference_inds_multiple_overlaps():
-    intervals1 = [(0, 10)]
-    intervals2 = [(1, 3), (4, 6), (7, 9)]
-    result = interval_set_difference_inds(intervals1, intervals2)
-    assert result == [(0, 1), (3, 4), (6, 7), (9, 10)]
diff --git a/tests/common/test_device.py b/tests/common/test_device.py
new file mode 100644
index 000000000..84323f2df
--- /dev/null
+++ b/tests/common/test_device.py
@@ -0,0 +1,40 @@
+import pytest
+from numpy import array_equal
+
+
+def test_invalid_device(common, populate_exception):
+    device_dict = common.DataAcquisitionDevice.fetch(as_dict=True)[0]
+    device_dict["other"] = "invalid"
+    with pytest.raises(populate_exception):
+        common.DataAcquisitionDevice._add_device(device_dict)
+
+
+def test_spikegadgets_system_alias(mini_insert, common):
+    assert (
+        common.DataAcquisitionDevice()._add_system("MCU") == "SpikeGadgets"
+    ), "SpikeGadgets MCU alias not found"
+
+
+def test_invalid_probe(common, populate_exception):
+    probe_dict = common.ProbeType.fetch(as_dict=True)[0]
+    probe_dict["other"] = "invalid"
+    with pytest.raises(populate_exception):
+        common.Probe._add_probe_type(probe_dict)
+
+
+def test_create_probe(common, mini_devices, mini_path, mini_copy_name):
+    probe_id = common.Probe.fetch("KEY", as_dict=True)[0]
+    probe_type = common.ProbeType.fetch("KEY", as_dict=True)[0]
+    before = common.Probe.fetch()
+    common.Probe.create_from_nwbfile(
+        nwb_file_name=mini_copy_name,
+        nwb_device_name="probe 0",
+        contact_side_numbering=False,
+        **probe_id,
+        **probe_type,
+    )
+    after = common.Probe.fetch()
+    # Because already inserted, expect no change
+    assert array_equal(
+        before, after
+    ), "Probe create_from_nwbfile had unexpected effect"
diff --git a/tests/common/test_dio.py b/tests/common/test_dio.py
new file mode 100644
index 000000000..f4b258dde
--- /dev/null
+++ b/tests/common/test_dio.py
@@ -0,0 +1,31 @@
+import pytest
+from numpy import allclose, array
+
+
+@pytest.fixture(scope="session")
+def dio_events(common):
+    yield common.common_dio.DIOEvents
+
+
+@pytest.fixture(scope="session")
+def dio_fig(mini_insert, dio_events, mini_restr):
+    yield (dio_events & mini_restr).plot_all_dio_events(return_fig=True)
+
+
+def test_plot_dio_axes(dio_fig, dio_events):
+    """Check that all events are plotted."""
+    events_fig = set(x.yaxis.get_label().get_text() for x in dio_fig.get_axes())
+    events_fetch = set(dio_events.fetch("dio_event_name"))
+    assert events_fig == events_fetch, "Mismatch in events plotted."
+
+
+def test_plot_dio_data(common, dio_fig):
+    """Compare data plotted to valid times fetched from IntervalList."""
+    data_fig = dio_fig.get_axes()[0].lines[0].get_xdata()
+    data_block = (
+        common.IntervalList & 'interval_list_name LIKE "raw%"'
+    ).fetch1("valid_times")
+    data_fetch = array((data_block[0][0], data_block[-1][1]))
+    assert allclose(
+        data_fig, data_fetch, atol=1e-8
+    ), "Mismatch in data plotted."
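Aside (not part of this diff; names below are illustrative): the DIO tests above rely on the `return_fig` pattern this PR adds across plotting methods. The default `return_fig=False` keeps the old display-only behavior, while `return_fig=True` hands the figure back so tests can assert against its artists. A minimal sketch:

```python
import matplotlib.pyplot as plt


def plot_events(data, return_fig=False):
    """Plot data; optionally hand the figure back for test assertions."""
    plt.figure()
    plt.plot(data)
    if return_fig:  # default False preserves display-only behavior
        return plt.gcf()


# A test can then inspect the plotted artists directly:
fig = plot_events([0, 1, 2], return_fig=True)
assert list(fig.get_axes()[0].lines[0].get_ydata()) == [0, 1, 2]
```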
diff --git a/tests/common/test_ephys.py b/tests/common/test_ephys.py
new file mode 100644
index 000000000..9ad1ea0a4
--- /dev/null
+++ b/tests/common/test_ephys.py
@@ -0,0 +1,33 @@
+import pytest
+from numpy import array_equal
+
+
+def test_create_from_config(mini_insert, common_ephys, mini_path):
+    before = common_ephys.Electrode().fetch()
+    common_ephys.Electrode.create_from_config(mini_path.stem)
+    after = common_ephys.Electrode().fetch()
+    # Because already inserted, expect no change
+    assert array_equal(
+        before, after
+    ), "Electrode.create_from_config had unexpected effect"
+
+
+def test_raw_object(mini_insert, common_ephys, mini_dict, mini_content):
+    obj_fetch = common_ephys.Raw().nwb_object(mini_dict).object_id
+    obj_raw = mini_content.get_acquisition().object_id
+    assert obj_fetch == obj_raw, "Raw.nwb_object did not return expected object"
+
+
+def test_set_lfp_electrodes(mini_insert, common_ephys, mini_copy_name):
+    before = common_ephys.LFPSelection().fetch()
+    common_ephys.LFPSelection().set_lfp_electrodes(mini_copy_name, [0])
+    after = common_ephys.LFPSelection().fetch()
+    # Expect one new entry for the added electrode
+    assert (
+        len(after) == len(before) + 1
+    ), "Set LFP electrodes had unexpected effect"
+
+
+@pytest.mark.skip(reason="Not testing V0: common lfp")
+def test_lfp():
+    pass
diff --git a/tests/common/test_filter.py b/tests/common/test_filter.py
new file mode 100644
index 000000000..9e0be584f
--- /dev/null
+++ b/tests/common/test_filter.py
@@ -0,0 +1,79 @@
+import pytest
+
+
+@pytest.fixture(scope="session")
+def filter_parameters(common):
+    yield common.FirFilterParameters()
+
+
+@pytest.fixture(scope="session")
+def filter_dict(filter_parameters):
+    yield {"filter_name": "test", "fs": 10}
+
+
+@pytest.fixture(scope="session")
+def add_filter(filter_parameters, filter_dict):
+    filter_parameters.add_filter(
+        **filter_dict, filter_type="lowpass", band_edges=[1, 2]
+    )
+
+
+@pytest.fixture(scope="session")
+def filter_coeff(filter_parameters, filter_dict):
+    yield filter_parameters._filter_restrict(**filter_dict)["filter_coeff"]
+
+
+def test_add_filter(filter_parameters, add_filter, filter_dict):
+    """Test add filter"""
+    assert filter_parameters & filter_dict, "add_filter failed"
+
+
+def test_filter_restrict(
+    filter_parameters, add_filter, filter_dict, filter_coeff
+):
+    assert sum(filter_coeff) == pytest.approx(
+        0.999134, abs=1e-6
+    ), "filter_restrict failed"
+
+
+def test_plot_magnitude(filter_parameters, add_filter, filter_dict):
+    fig = filter_parameters.plot_magnitude(**filter_dict, return_fig=True)
+    assert sum(fig.get_axes()[0].lines[0].get_xdata()) == pytest.approx(
+        163837.5, abs=1
+    ), "plot_magnitude failed"
+
+
+def test_plot_fir_filter(
+    filter_parameters, add_filter, filter_dict, filter_coeff
+):
+    fig = filter_parameters.plot_fir_filter(**filter_dict, return_fig=True)
+    assert sum(fig.get_axes()[0].lines[0].get_ydata()) == sum(
+        filter_coeff
+    ), "Plot filter failed"
+
+
+def test_filter_delay(filter_parameters, add_filter, filter_dict):
+    delay = filter_parameters.filter_delay(**filter_dict)
+    assert delay == 27, "filter_delay failed"
+
+
+def test_time_bound_warning(filter_parameters, add_filter, filter_dict):
+    with pytest.warns(UserWarning):
+        filter_parameters._time_bound_check(1, 3, [2, 5], 4)
+
+
+@pytest.mark.skip(reason="Not testing V0: filter_data")
+def test_filter_data(filter_parameters, mini_content):
+    pass
+
+
+def test_calc_filter_delay(filter_parameters, filter_coeff):
+    delay =
filter_parameters.calc_filter_delay(filter_coeff)
+    assert delay == 27, "calc_filter_delay failed"
+
+
+def test_create_standard_filters(filter_parameters):
+    filter_parameters.create_standard_filters()
+    assert filter_parameters & {
+        "filter_name": "LFP 0-400 Hz"
+    }, "create_standard_filters failed"
diff --git a/tests/common/test_insert.py b/tests/common/test_insert.py
new file mode 100644
index 000000000..6d2fd18b3
--- /dev/null
+++ b/tests/common/test_insert.py
@@ -0,0 +1,220 @@
+from datajoint.hash import key_hash
+from pandas import DataFrame, Index
+from pytest import approx
+
+
+def test_insert_session(mini_insert, mini_content, mini_restr, common):
+    subj_raw = mini_content.subject
+    meta_raw = mini_content
+
+    sess_data = (common.Session & mini_restr).fetch1()
+    assert (
+        sess_data["subject_id"] == subj_raw.subject_id
+    ), "Subject ID does not match"
+
+    attrs = [
+        ("institution_name", "institution"),
+        ("lab_name", "lab"),
+        ("session_id", "session_id"),
+        ("session_description", "session_description"),
+        ("experiment_description", "experiment_description"),
+    ]
+
+    for sess_attr, meta_attr in attrs:
+        assert sess_data[sess_attr] == getattr(
+            meta_raw, meta_attr
+        ), f"Session table {sess_attr} does not match raw data {meta_attr}"
+
+    time_attrs = [
+        ("session_start_time", "session_start_time"),
+        ("timestamps_reference_time", "timestamps_reference_time"),
+    ]
+    for sess_attr, meta_attr in time_attrs:
+        # a. strip timezone info from meta_raw
+        # b. convert to timestamp
+        # c. compare precision to 1 second
+        assert sess_data[sess_attr].timestamp() == approx(
+            getattr(meta_raw, meta_attr).replace(tzinfo=None).timestamp(), abs=1
+        ), f"Session table {sess_attr} does not match raw data {meta_attr}"
+
+
+def test_insert_electrode_group(mini_insert, mini_content, common):
+    group_name = "0"
+    egroup_data = (
+        common.ElectrodeGroup & {"electrode_group_name": group_name}
+    ).fetch1()
+    egroup_raw = mini_content.electrode_groups.get(group_name)
+
+    assert (
+        egroup_data["description"] == egroup_raw.description
+    ), "ElectrodeGroup description does not match"
+
+    assert egroup_data["region_id"] == (
+        common.BrainRegion & {"region_name": egroup_raw.location}
+    ).fetch1(
+        "region_id"
+    ), "Region ID does not match across raw data and BrainRegion table"
+
+
+def test_insert_electrode(mini_insert, mini_content, mini_restr, common):
+    electrode_id = "0"
+    e_data = (common.Electrode & {"electrode_id": electrode_id}).fetch1()
+    e_raw = mini_content.electrodes.get(int(electrode_id)).to_dict().copy()
+
+    attrs = [
+        ("x", "x"),
+        ("y", "y"),
+        ("z", "z"),
+        ("impedance", "imp"),
+        ("filtering", "filtering"),
+        ("original_reference_electrode", "ref_elect_id"),
+    ]
+
+    for e_attr, meta_attr in attrs:
+        assert (
+            e_data[e_attr] == e_raw[meta_attr][int(electrode_id)]
+        ), f"Electrode table {e_attr} does not match raw data {meta_attr}"
+
+
+def test_insert_raw(mini_insert, mini_content, mini_restr, common):
+    raw_data = (common.Raw & mini_restr).fetch1()
+    raw_raw = mini_content.get_acquisition()
+
+    attrs = [
+        ("comments", "comments"),
+        ("description", "description"),
+    ]
+    for raw_attr, meta_attr in attrs:
+        assert raw_data[raw_attr] == getattr(
+            raw_raw, meta_attr
+        ), f"Raw table {raw_attr} does not match raw data {meta_attr}"
+
+
+def test_insert_sample_count(mini_insert, mini_content, mini_restr, common):
+    sample_data = (common.SampleCount & mini_restr).fetch1()
+    sample_full = mini_content.processing.get("sample_count")
+    if not sample_full:
+        assert False, "No sample count data in raw data"
+    sample_raw =
sample_full.data_interfaces.get("sample_count")
+    assert (
+        sample_data["sample_count_object_id"] == sample_raw.object_id
+    ), "SampleCount insertion error"
+
+
+def test_insert_dio(mini_insert, mini_behavior, mini_restr, common):
+    events_data = (common.DIOEvents & mini_restr).fetch(as_dict=True)
+    events_raw = mini_behavior.get_data_interface(
+        "behavioral_events"
+    ).time_series
+
+    assert len(events_data) == len(events_raw), "Number of events does not match"
+
+    event = [p for p in events_raw.keys() if "Poke" in p][0]
+    event_raw = events_raw.get(event)
+    event_data = (common.DIOEvents & {"dio_event_name": event}).fetch1()
+
+    assert (
+        event_data["dio_object_id"] == event_raw.object_id
+    ), "DIO Event insertion error"
+
+
+def test_insert_pos(
+    mini_insert,
+    common,
+    mini_behavior,
+    mini_restr,
+    mini_pos_series,
+    mini_pos_tbl,
+):
+    pos_data = (common.PositionSource.SpatialSeries & mini_restr).fetch()
+    pos_raw = mini_behavior.get_data_interface("position").spatial_series
+
+    assert len(pos_data) == len(pos_raw), "Number of spatial series does not match"
+
+    raw_obj_id = pos_raw[mini_pos_series].object_id
+    data_obj_id = mini_pos_tbl.fetch1("raw_position_object_id")
+
+    assert data_obj_id == raw_obj_id, "PosObject insertion error"
+
+
+def test_fetch_posobj(
+    mini_insert, common, mini_pos, mini_pos_series, mini_pos_tbl
+):
+    pos_key = (
+        common.PositionSource.SpatialSeries & mini_pos_tbl.fetch("KEY")
+    ).fetch(as_dict=True)[0]
+    pos_df = (common.RawPosition & pos_key).fetch1_dataframe().iloc[:, 0:2]
+
+    series = mini_pos[mini_pos_series]
+    raw_df = DataFrame(
+        data=series.data,
+        index=Index(series.timestamps, name="time"),
+        columns=[col + "1" for col in series.description.split(", ")],
+    )
+    assert key_hash(pos_df) == key_hash(raw_df), "Spatial series fetch error"
+
+
+def test_insert_device(mini_insert, mini_devices, common):
+    this_device = "dataacq_device0"
+    device_raw = mini_devices.get(this_device)
+    device_data = (
+        common.DataAcquisitionDevice
+        & {"data_acquisition_device_name": this_device}
+    ).fetch1()
+
+    attrs = [
+        ("data_acquisition_device_name", "name"),
+        ("data_acquisition_device_system", "system"),
+        ("data_acquisition_device_amplifier", "amplifier"),
+        ("adc_circuit", "adc_circuit"),
+    ]
+
+    for device_attr, meta_attr in attrs:
+        assert device_data[device_attr] == getattr(
+            device_raw, meta_attr
+        ), f"Device table {device_attr} does not match raw data {meta_attr}"
+
+
+def test_insert_camera(mini_insert, mini_devices, common):
+    camera_raw = mini_devices.get("camera_device 0")
+    camera_data = (
+        common.CameraDevice & {"camera_name": camera_raw.camera_name}
+    ).fetch1()
+
+    attrs = [
+        ("camera_name", "camera_name"),
+        ("manufacturer", "manufacturer"),
+        ("model", "model"),
+        ("lens", "lens"),
+        ("meters_per_pixel", "meters_per_pixel"),
+    ]
+    for camera_attr, meta_attr in attrs:
+        assert camera_data[camera_attr] == getattr(
+            camera_raw, meta_attr
+        ), f"Camera table {camera_attr} does not match raw data {meta_attr}"
+
+
+def test_insert_probe(mini_insert, mini_devices, common):
+    this_probe = "probe 0"
+    probe_raw = mini_devices.get(this_probe)
+    probe_id = probe_raw.probe_type
+
+    probe_data = (
+        common.Probe * common.ProbeType & {"probe_id": probe_id}
+    ).fetch1()
+
+    attrs = [
+        ("probe_type", "probe_type"),
+        ("probe_description", "probe_description"),
+        ("contact_side_numbering", "contact_side_numbering"),
+    ]
+
+    for probe_attr, meta_attr in attrs:
+        assert
probe_data[probe_attr] == str(
+            getattr(probe_raw, meta_attr)
+        ), f"Probe table {probe_attr} does not match raw data {meta_attr}"
+
+    assert probe_data["num_shanks"] == len(
+        probe_raw.shanks
+    ), "Number of shanks in ProbeType does not match raw data"
diff --git a/tests/common/test_interval.py b/tests/common/test_interval.py
new file mode 100644
index 000000000..8353961f8
--- /dev/null
+++ b/tests/common/test_interval.py
@@ -0,0 +1,27 @@
+import pytest
+from numpy import array_equal
+
+
+@pytest.fixture(scope="session")
+def interval_list(common):
+    yield common.IntervalList()
+
+
+def test_plot_intervals(mini_insert, interval_list):
+    fig = interval_list.plot_intervals(return_fig=True)
+    interval_list_name = fig.get_axes()[0].get_yticklabels()[0].get_text()
+    times_fetch = (
+        interval_list & {"interval_list_name": interval_list_name}
+    ).fetch1("valid_times")[0]
+    times_plot = fig.get_axes()[0].lines[0].get_xdata()
+
+    assert array_equal(times_fetch, times_plot), "plot_intervals failed"
+
+
+def test_plot_epoch(mini_insert, interval_list):
+    fig = interval_list.plot_epoch_pos_raw_intervals(return_fig=True)
+    epoch_label = fig.get_axes()[0].get_yticklabels()[-1].get_text()
+    assert epoch_label == "epoch", "plot_epoch failed"
+
+    epoch_interv = fig.get_axes()[0].lines[0].get_ydata()
+    assert array_equal(epoch_interv, [1, 1]), "plot_epoch failed"
diff --git a/tests/common/test_interval_helpers.py b/tests/common/test_interval_helpers.py
new file mode 100644
index 000000000..d4e7eb1ac
--- /dev/null
+++ b/tests/common/test_interval_helpers.py
@@ -0,0 +1,272 @@
+import numpy as np
+import pytest
+
+
+@pytest.fixture(scope="session")
+def list_intersect(common):
+    yield common.common_interval.interval_list_intersect
+
+
+@pytest.mark.parametrize(
+    "one, two, result",
+    [
+        (
+            np.array([[0, 10], [3, 5], [14, 16]]),
+            np.array([[10, 11], [9, 14], [13, 18]]),
+            np.array([[9, 10], [14, 16]]),
+        ),
+        (  # Empty result for no intersection
+            np.array([[0, 10], [3, 5]]),
+            np.array([[11, 14]]),
+            np.array([]),
+        ),
+    ],
+)
+def test_list_intersect(list_intersect, one, two, result):
+    assert np.array_equal(
+        list_intersect(one, two), result
+    ), "Problem with common_interval.interval_list_intersect"
+
+
+@pytest.fixture(scope="session")
+def set_difference(common):
+    yield common.common_interval.interval_set_difference_inds
+
+
+@pytest.mark.parametrize(
+    "one, two, expected_result",
+    [
+        (  # No overlap
+            [(0, 5), (8, 10)],
+            [(5, 8)],
+            [(0, 5), (8, 10)],
+        ),
+        (  # Overlap
+            [(0, 5), (8, 10)],
+            [(1, 2), (3, 4), (6, 9)],
+            [(0, 1), (2, 3), (4, 5), (9, 10)],
+        ),
+        (  # One empty
+            [],
+            [(1, 2), (3, 4), (6, 9)],
+            [],
+        ),
+        (  # Two empty
+            [(0, 5), (8, 10)],
+            [],
+            [(0, 5), (8, 10)],
+        ),
+        (  # Equal intervals
+            [(0, 5), (8, 10)],
+            [(0, 5), (8, 10)],
+            [],
+        ),
+        (  # Multiple overlaps
+            [(0, 10)],
+            [(1, 3), (4, 6), (7, 9)],
+            [(0, 1), (3, 4), (6, 7), (9, 10)],
+        ),
+    ],
+)
+def test_set_difference(set_difference, one, two, expected_result):
+    assert (
+        set_difference(one, two) == expected_result
+    ), "Problem with common_interval.interval_set_difference_inds"
+
+
+@pytest.mark.parametrize(
+    "expected_result, min_len, max_len",
+    [
+        (np.array([[0, 1]]), 0.0, 10),
+        (np.array([[0, 1], [0, 1e11]]), 0.0, 1e12),
+        (np.array([[0, 0], [0, 1]]), -1, 10),
+    ],
+)
+def test_intervals_by_length(common, expected_result, min_len, max_len):
+    # input is the same across all tests.
Could be parametrized as above
+    inds = common.common_interval.intervals_by_length(
+        interval_list=np.array([[0, 0], [0, 1], [0, 1e11]]),
+        min_length=min_len,
+        max_length=max_len,
+    )
+    assert np.array_equal(
+        inds, expected_result
+    ), "Problem with common_interval.intervals_by_length"
+
+
+@pytest.fixture
+def interval_list_dict():
+    yield {
+        "interval_list": np.array([[1, 4], [6, 8]]),
+        "timestamps": np.array([0, 1, 5, 7, 8, 9]),
+    }
+
+
+def test_interval_list_contains_ind(common, interval_list_dict):
+    idxs = common.common_interval.interval_list_contains_ind(
+        **interval_list_dict
+    )
+    assert np.array_equal(
+        idxs, np.array([1, 3, 4])
+    ), "Problem with common_interval.interval_list_contains_ind"
+
+
+def test_interval_list_contains(common, interval_list_dict):
+    idxs = common.common_interval.interval_list_contains(**interval_list_dict)
+    assert np.array_equal(
+        idxs, np.array([1, 7, 8])
+    ), "Problem with common_interval.interval_list_contains"
+
+
+def test_interval_list_excludes_ind(common, interval_list_dict):
+    idxs = common.common_interval.interval_list_excludes_ind(
+        **interval_list_dict
+    )
+    assert np.array_equal(
+        idxs, np.array([0, 2, 5])
+    ), "Problem with common_interval.interval_list_excludes_ind"
+
+
+def test_interval_list_excludes(common, interval_list_dict):
+    idxs = common.common_interval.interval_list_excludes(**interval_list_dict)
+    assert np.array_equal(
+        idxs, np.array([0, 5, 9])
+    ), "Problem with common_interval.interval_list_excludes"
+
+
+def test_consolidate_intervals_1dim(common):
+    exp = common.common_interval.consolidate_intervals(np.array([0, 1]))
+    assert np.array_equal(
+        exp, np.array([[0, 1]])
+    ), "Problem with common_interval.consolidate_intervals"
+
+
+@pytest.mark.parametrize(
+    "interval1, interval2, exp_result",
+    [
+        (
+            np.array([[0, 1]]),
+            np.array([[2, 3]]),
+            np.array([[0, 3]]),
+        ),
+        (
+            np.array([[2, 3]]),
+            np.array([[0, 1]]),
+            np.array([[0, 3]]),
+        ),
+        (
+            np.array([[0, 3]]),
+            np.array([[2, 4]]),
+            np.array([[0, 3], [2, 4]]),
+        ),
+    ],
+)
+def test_union_adjacent_index(common, interval1, interval2, exp_result):
+    assert np.array_equal(
+        common.common_interval.union_adjacent_index(interval1, interval2),
+        exp_result,
+    ), "Problem with common_interval.union_adjacent_index"
+
+
+@pytest.mark.parametrize(
+    "interval1, interval2, exp_result",
+    [
+        (
+            np.array([[0, 3]]),
+            np.array([[2, 4]]),
+            np.array([[0, 4]]),
+        ),
+        (
+            np.array([[0, -1]]),
+            np.array([[2, 4]]),
+            np.array([[2, 0]]),
+        ),
+        (
+            np.array([[0, 1]]),
+            np.array([[2, 1e11]]),
+            np.array([[0, 1], [2, 1e11]]),
+        ),
+    ],
+)
+def test_interval_list_union(common, interval1, interval2, exp_result):
+    assert np.array_equal(
+        common.common_interval.interval_list_union(interval1, interval2),
+        exp_result,
+    ), "Problem with common_interval.interval_list_union"
+
+
+def test_interval_list_censor_error(common):
+    with pytest.raises(ValueError):
+        common.common_interval.interval_list_censor(
+            np.array([[0, 1]]), np.array([2])
+        )
+
+
+def test_interval_list_censor(common):
+    assert np.array_equal(
+        common.common_interval.interval_list_censor(
+            np.array([[0, 2], [4, 5]]), np.array([1, 2, 4])
+        ),
+        np.array([[1, 2]]),
+    ), "Problem with common_interval.interval_list_censor"
+
+
+@pytest.mark.parametrize(
+    "interval_list, exp_result",
+    [
+        (
+            np.array([0, 1, 2, 3, 6, 7, 8, 9]),
+            np.array([[0, 3], [6, 9]]),
+        ),
+        (
+            np.array([0, 1, 2]),
+            np.array([[0, 2]]),
+        ),
+        (
+            np.array([2, 3, 1, 0]),
+            np.array([[0, 3]]),
+        ),
+        (
+            np.array([2, 3, 0]),
np.array([[0, 0], [2, 3]]),
+        ),
+    ],
+)
+def test_interval_from_inds(common, interval_list, exp_result):
+    assert np.array_equal(
+        common.common_interval.interval_from_inds(interval_list),
+        exp_result,
+    ), "Problem with common_interval.interval_from_inds"
+
+
+@pytest.mark.parametrize(
+    "intervals1, intervals2, min_length, exp_result",
+    [
+        (
+            np.array([[0, 2], [4, 5]]),
+            np.array([[1, 3], [2, 4]]),
+            0,
+            np.array([[0, 1], [4, 5]]),
+        ),
+        (
+            np.array([[0, 2], [4, 5]]),
+            np.array([[1, 3], [2, 4]]),
+            1,
+            np.zeros((0, 2)),
+        ),
+        (
+            np.array([[0, 2], [4, 6]]),
+            np.array([[5, 8], [2, 4]]),
+            1,
+            np.array([[0, 2]]),
+        ),
+    ],
+)
+def test_interval_list_complement(
+    common, intervals1, intervals2, min_length, exp_result
+):
+    ic = common.common_interval.interval_list_complement
+    assert np.array_equal(
+        ic(intervals1, intervals2, min_length),
+        exp_result,
+    ), "Problem with common_interval.interval_list_complement"
diff --git a/tests/common/test_lab.py b/tests/common/test_lab.py
new file mode 100644
index 000000000..83ab84c10
--- /dev/null
+++ b/tests/common/test_lab.py
@@ -0,0 +1,110 @@
+import pytest
+from numpy import array_equal
+
+
+@pytest.fixture
+def common_lab(common):
+    yield common.common_lab
+
+
+@pytest.fixture
+def add_admin(common_lab):
+    common_lab.LabMember.insert1(
+        dict(
+            lab_member_name="This Admin",
+            first_name="This",
+            last_name="Admin",
+        ),
+        skip_duplicates=True,
+    )
+    common_lab.LabMember.LabMemberInfo.insert1(
+        dict(
+            lab_member_name="This Admin",
+            google_user_name="This Admin",
+            datajoint_user_name="this_admin",
+            admin=1,
+        ),
+        skip_duplicates=True,
+    )
+    yield
+
+
+@pytest.fixture
+def add_member_team(common_lab, add_admin):
+    common_lab.LabMember.insert(
+        [
+            dict(
+                lab_member_name="This Basic",
+                first_name="This",
+                last_name="Basic",
+            ),
+            dict(
+                lab_member_name="This Loner",
+                first_name="This",
+                last_name="Loner",
+            ),
+        ],
+        skip_duplicates=True,
+    )
+    common_lab.LabMember.LabMemberInfo.insert(
+        [
+            dict(
+                lab_member_name="This Basic",
+                google_user_name="This Basic",
+                datajoint_user_name="this_basic",
+                admin=0,
+            ),
+            dict(
+                lab_member_name="This Loner",
+                google_user_name="This Loner",
+                datajoint_user_name="this_loner",
+                admin=0,
+            ),
+        ],
+        skip_duplicates=True,
+    )
+    common_lab.LabTeam.create_new_team(
+        team_name="This Team",
+        team_members=["This Admin", "This Basic"],
+        team_description="This Team Description",
+    )
+    yield
+
+
+def test_labmember_insert_file_str(mini_insert, common_lab, mini_copy_name):
+    before = common_lab.LabMember.fetch()
+    common_lab.LabMember.insert_from_nwbfile(mini_copy_name)
+    after = common_lab.LabMember.fetch()
+    # Already inserted, test func raises no error
+    assert array_equal(before, after), "LabMember not inserted correctly"
+
+
+def test_fetch_admin(common_lab, add_admin):
+    assert (
+        "this_admin" in common_lab.LabMember().admin
+    ), "LabMember admin not fetched correctly"
+
+
+def test_get_djuser(common_lab, add_admin):
+    assert "This Admin" == common_lab.LabMember().get_djuser_name(
+        "this_admin"
+    ), "LabMember get_djuser not fetched correctly"
+
+
+def test_get_djuser_error(common_lab, add_admin):
+    with pytest.raises(ValueError):
+        common_lab.LabMember().get_djuser_name("This Admin2")
+
+
+def test_get_team_members(common_lab, add_member_team):
+    assert common_lab.LabTeam().get_team_members("This Admin") == set(
+        ("This Admin", "This Basic")
+    ), "LabTeam get_team_members not fetched correctly"
+
+
+def test_decompose_name_error(common_lab):
+    # NOTE: Should change with
solve of #304 + with pytest.raises(ValueError): + common_lab.decompose_name("This Invalid Name") + with pytest.raises(ValueError): + common_lab.decompose_name("This, Invalid, Name") diff --git a/tests/common/test_nwbfile.py b/tests/common/test_nwbfile.py new file mode 100644 index 000000000..a8671b7ce --- /dev/null +++ b/tests/common/test_nwbfile.py @@ -0,0 +1,41 @@ +import os + +import pytest + + +@pytest.fixture +def common_nwbfile(common): + """Return a common NWBFile object.""" + return common.common_nwbfile + + +@pytest.fixture +def lockfile(base_dir, teardown): + lockfile = base_dir / "temp.lock" + lockfile.touch() + os.environ["NWB_LOCK_FILE"] = str(lockfile) + yield lockfile + if teardown: + os.remove(lockfile) + + +def test_get_file_name_error(common_nwbfile): + """Test that an error is raised when trying non-existent file.""" + with pytest.raises(ValueError): + common_nwbfile.Nwbfile._get_file_name("non-existent-file.nwb") + + +def test_add_to_lock(common_nwbfile, lockfile, mini_copy_name): + common_nwbfile.Nwbfile.add_to_lock(mini_copy_name) + with lockfile.open("r") as f: + assert mini_copy_name in f.read() + + with pytest.raises(AssertionError): + common_nwbfile.Nwbfile.add_to_lock("non-existent-file.nwb") + + +def test_nwbfile_cleanup(common_nwbfile): + before = len(common_nwbfile.Nwbfile.fetch()) + common_nwbfile.Nwbfile.cleanup(delete_files=False) + after = len(common_nwbfile.Nwbfile.fetch()) + assert before == after, "Nwbfile cleanup changed table entry count." diff --git a/tests/common/test_position.py b/tests/common/test_position.py new file mode 100644 index 000000000..47f285977 --- /dev/null +++ b/tests/common/test_position.py @@ -0,0 +1,151 @@ +import pytest +from datajoint.hash import key_hash + + +@pytest.fixture +def common_position(common): + yield common.common_position + + +@pytest.fixture +def interval_position_info(common_position): + yield common_position.IntervalPositionInfo + + +@pytest.fixture +def default_param_key(): + yield {"position_info_param_name": "default"} + + +@pytest.fixture +def interval_key(common): + yield (common.IntervalList & "interval_list_name LIKE 'pos 0%'").fetch1( + "KEY" + ) + + +@pytest.fixture +def param_table(common_position, default_param_key, teardown): + param_table = common_position.PositionInfoParameters() + param_table.insert1(default_param_key, skip_duplicates=True) + yield param_table + if teardown: + param_table.delete(safemode=False) + + +@pytest.fixture +def upsample_position( + common, + common_position, + param_table, + default_param_key, + teardown, + interval_key, +): + params = (param_table & default_param_key).fetch1() + upsample_param_key = {"position_info_param_name": "upsampled"} + param_table.insert1( + { + **params, + **upsample_param_key, + "is_upsampled": 1, + "max_separation": 80, + "upsampling_sampling_rate": 500, + }, + skip_duplicates=True, + ) + interval_pos_key = {**interval_key, **upsample_param_key} + common_position.IntervalPositionInfoSelection.insert1( + interval_pos_key, skip_duplicates=True + ) + common_position.IntervalPositionInfo.populate(interval_pos_key) + yield interval_pos_key + if teardown: + (param_table & upsample_param_key).delete(safemode=False) + + +@pytest.fixture +def interval_pos_key(upsample_position): + yield upsample_position + + +def test_interval_position_info_insert(common_position, interval_pos_key): + assert common_position.IntervalPositionInfo & interval_pos_key + + +@pytest.fixture +def upsample_position_error( + upsample_position, + default_param_key, + param_table, 
+    common,
+    common_position,
+    teardown,
+    interval_key,
+):
+    params = (param_table & default_param_key).fetch1()
+    upsample_param_key = {"position_info_param_name": "upsampled error"}
+    param_table.insert1(
+        {
+            **params,
+            **upsample_param_key,
+            "is_upsampled": 1,
+            "max_separation": 1,
+            "upsampling_sampling_rate": 500,
+        },
+        skip_duplicates=True,
+    )
+    interval_pos_key = {**interval_key, **upsample_param_key}
+    common_position.IntervalPositionInfoSelection.insert1(interval_pos_key)
+    yield interval_pos_key
+    if teardown:
+        (param_table & upsample_param_key).delete(safemode=False)
+
+
+def test_interval_position_info_insert_error(
+    interval_position_info, upsample_position_error
+):
+    with pytest.raises(ValueError):
+        interval_position_info.populate(upsample_position_error)
+
+
+def test_fetch1_dataframe(interval_position_info, interval_pos_key):
+    df = (interval_position_info & interval_pos_key).fetch1_dataframe()
+    err_msg = "Unexpected output of IntervalPositionInfo.fetch1_dataframe"
+    assert df.shape == (5193, 6), err_msg
+
+    df_sums = {c: df[c].iloc[:5].sum() for c in df.columns}
+    df_sums_exp = {
+        "head_orientation": 4.4300073600180125,
+        "head_position_x": 111.25,
+        "head_position_y": 141.75,
+        "head_speed": 0.6084872579024899,
+        "head_velocity_x": -0.4329520555149495,
+        "head_velocity_y": 0.42756198762527325,
+    }
+    for k in df_sums:
+        assert k in df_sums_exp, err_msg
+        assert df_sums[k] == pytest.approx(df_sums_exp[k], rel=0.02), err_msg
+
+
+def test_interval_position_info_kwarg_error(interval_position_info):
+    with pytest.raises(ValueError):
+        interval_position_info._fix_kwargs()
+
+
+def test_interval_position_info_kwarg_alias(interval_position_info):
+    in_tuple = (0, 1, 2, 3)
+    out_tuple = interval_position_info._fix_kwargs(
+        head_orient_smoothing_std_dev=in_tuple[0],
+        head_speed_smoothing_std_dev=in_tuple[1],
+        max_separation=in_tuple[2],
+        max_speed=in_tuple[3],
+    )
+    assert (
+        out_tuple == in_tuple
+    ), "IntervalPositionInfo._fix_kwargs() should alias old arg names."
+
+
+@pytest.mark.skip(reason="Not testing with video data yet.")
+def test_position_video(common_position):
+    pass
diff --git a/tests/common/test_region.py b/tests/common/test_region.py
new file mode 100644
index 000000000..95f62fe1b
--- /dev/null
+++ b/tests/common/test_region.py
@@ -0,0 +1,29 @@
+import pytest
+from datajoint import U as dj_U
+
+
+@pytest.fixture
+def region_dict():
+    yield dict(region_name="test_region")
+
+
+@pytest.fixture
+def brain_region(common, region_dict):
+    brain_region = common.common_region.BrainRegion()
+    (brain_region & "region_id > 1").delete(safemode=False)
+    yield brain_region
+    (brain_region & "region_id > 1").delete(safemode=False)
+
+
+def test_region_add(brain_region, region_dict):
+    next_id = (
+        dj_U().aggr(brain_region, n="max(region_id)").fetch1("n") or 0
+    ) + 1
+    region_id = brain_region.fetch_add(
+        **region_dict,
+        subregion_name="test_subregion_add",
+        subsubregion_name="test_subsubregion_add",
+    )
+    assert (
+        region_id == next_id
+    ), "Region.fetch_add() should autoincrement region_id."
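+
+
+# Note: fetch_add() is expected to insert the region if it is missing and
+# return its region_id, which is why the test compares against the current
+# max(region_id) + 1.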
diff --git a/tests/common/test_ripple.py b/tests/common/test_ripple.py
new file mode 100644
index 000000000..71a57d022
--- /dev/null
+++ b/tests/common/test_ripple.py
@@ -0,0 +1,6 @@
+import pytest
+
+
+@pytest.mark.skip(reason="Not testing V0: common_ripple")
+def test_common_ripple(common):
+    pass
diff --git a/tests/common/test_sensors.py b/tests/common/test_sensors.py
new file mode 100644
index 000000000..9cdedeeb4
--- /dev/null
+++ b/tests/common/test_sensors.py
@@ -0,0 +1,21 @@
+import pytest
+
+
+@pytest.fixture
+def sensor_data(common, mini_insert):
+    tbl = common.common_sensors.SensorData()
+    tbl.populate()
+    yield tbl
+
+
+def test_sensor_data_insert(sensor_data, mini_insert, mini_restr, mini_content):
+    obj_fetch = (sensor_data & mini_restr).fetch1("sensor_data_object_id")
+    obj_raw = (
+        mini_content.processing["analog"]
+        .data_interfaces["analog"]
+        .time_series["analog"]
+        .object_id
+    )
+    assert (
+        obj_fetch == obj_raw
+    ), "SensorData object_id does not match raw object_id."
diff --git a/tests/common/test_session.py b/tests/common/test_session.py
new file mode 100644
index 000000000..6e0a8f0ce
--- /dev/null
+++ b/tests/common/test_session.py
@@ -0,0 +1,81 @@
+import pytest
+from datajoint.errors import DataJointError
+
+
+@pytest.fixture
+def common_session(common):
+    return common.common_session
+
+
+@pytest.fixture
+def group_name_dict():
+    return {"session_group_name": "group1"}
+
+
+@pytest.fixture
+def add_session_group(common_session, group_name_dict):
+    session_group = common_session.SessionGroup()
+    session_group_dict = {
+        **group_name_dict,
+        "session_group_description": "group1 description",
+    }
+    session_group.add_group(**session_group_dict, skip_duplicates=True)
+    session_group_dict["session_group_description"] = "updated description"
+    session_group.update_session_group_description(**session_group_dict)
+    yield session_group, session_group_dict
+
+
+@pytest.fixture
+def session_group(add_session_group):
+    yield add_session_group[0]
+
+
+@pytest.fixture
+def session_group_dict(add_session_group):
+    yield add_session_group[1]
+
+
+def test_session_group_add(session_group, session_group_dict):
+    assert session_group & session_group_dict, "Session group not added"
+
+
+@pytest.fixture
+def add_session_to_group(session_group, mini_copy_name, group_name_dict):
+    session_group.add_session_to_group(
+        nwb_file_name=mini_copy_name, **group_name_dict
+    )
+
+
+def test_addremove_session_group(
+    common_session,
+    session_group,
+    session_group_dict,
+    group_name_dict,
+    mini_copy_name,
+    add_session_to_group,
+    add_session_group,
+):
+    assert session_group & session_group_dict, "Session not added to group"
+
+    session_group.remove_session_from_group(
+        nwb_file_name=mini_copy_name,
+        safemode=False,
+        **group_name_dict,
+    )
+    assert (
+        len(common_session.SessionGroupSession & session_group_dict) == 0
+    ), "SessionGroupSession not removed by helper function"
+
+
+def test_get_group_sessions(
+    session_group, group_name_dict, add_session_to_group
+):
+    ret = session_group.get_group_sessions(**group_name_dict)
+    assert len(ret) == 1, "Incorrect number of sessions returned"
+
+
+def test_delete_group_error(session_group, group_name_dict):
+    session_group.delete_group(**group_name_dict, safemode=False)
+    assert (
+        len(session_group & group_name_dict) == 0
+    ), "Group not deleted by helper function"
diff --git a/tests/conftest.py b/tests/conftest.py
index ac1539abf..3c2bc866b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,80 +1,326 @@
-# 
directory-specific hook implementations
 import os
-import shutil
 import sys
-import tempfile
+import warnings
+from contextlib import nullcontext
+from pathlib import Path
+from subprocess import Popen
+from time import sleep as tsleep
 
 import datajoint as dj
+import pynwb
+import pytest
+from datajoint.logging import logger as dj_logger
 
-from .datajoint._config import DATAJOINT_SERVER_PORT
-from .datajoint._datajoint_server import (
-    kill_datajoint_server,
-    run_datajoint_server,
-)
+from .container import DockerMySQLManager
 
-thisdir = os.path.dirname(os.path.realpath(__file__))
-sys.path.append(thisdir)
+# ---------------------- CONSTANTS ---------------------
 
-
-global __PROCESS
-__PROCESS = None
+# globals in pytest_configure:
+# BASE_DIR, RAW_DIR, SERVER, TEARDOWN, VERBOSE, TEST_FILE, DOWNLOAD
+warnings.filterwarnings("ignore", category=UserWarning, module="hdmf")
 
 
 def pytest_addoption(parser):
+    """Permit command-line options when calling pytest
+
+    Example
+    -------
+    > pytest --quiet-spy
+
+    Parameters
+    ----------
+    --quiet-spy (bool): Default False. Suppress logging and print statements
+        from Spyglass.
+    --no-teardown (bool): Default False. Do not delete the pipeline on close.
+    --no-server (bool): Default False. Do not run a datajoint server in Docker.
+    --base-dir (str): Default './tests/_data/'. Dir for local input file.
+    """
+    parser.addoption(
+        "--quiet-spy",
+        action="store_true",
+        dest="quiet_spy",
+        default=False,
+        help="Quiet logging from Spyglass.",
+    )
     parser.addoption(
-        "--current",
+        "--no-server",
         action="store_true",
-        dest="current",
+        dest="no_server",
         default=False,
-        help="run only tests marked as current",
+        help="Do not launch datajoint server in Docker.",
+    )
+    parser.addoption(
+        "--no-teardown",
+        action="store_true",
+        default=False,
+        dest="no_teardown",
+        help="Do not tear down the database after tests.",
+    )
+    parser.addoption(
+        "--base-dir",
+        action="store",
+        default="./tests/_data/",
+        dest="base_dir",
+        help="Directory for local input file.",
     )
 
 
 def pytest_configure(config):
-    config.addinivalue_line(
-        "markers", "current: for convenience -- mark one test as current"
+    global BASE_DIR, RAW_DIR, SERVER, TEARDOWN, VERBOSE, TEST_FILE, DOWNLOAD
+
+    TEST_FILE = "minirec20230622.nwb"
+    TEARDOWN = not config.option.no_teardown
+    VERBOSE = not config.option.quiet_spy
+
+    BASE_DIR = Path(config.option.base_dir).absolute()
+    BASE_DIR.mkdir(parents=True, exist_ok=True)
+    RAW_DIR = BASE_DIR / "raw"
+    os.environ["SPYGLASS_BASE_DIR"] = str(BASE_DIR)
+
+    SERVER = DockerMySQLManager(
+        restart=False,
+        shutdown=TEARDOWN,
+        null_server=config.option.no_server,
+        verbose=VERBOSE,
     )
+    DOWNLOAD = download_data(verbose=VERBOSE)
+
 
-    markexpr_list = []
+def data_is_downloaded():
+    """Check if data is downloaded."""
+    return os.path.exists(RAW_DIR / TEST_FILE)
 
-    if config.option.current:
-        markexpr_list.append("current")
 
-    if len(markexpr_list) > 0:
-        markexpr = " and ".join(markexpr_list)
-        setattr(config.option, "markexpr", markexpr)
+def download_data(verbose=False):
+    """Download data from BOX using environment variable credentials.
 
-    _set_env()
+    Note: In gh-actions, this is handled by the test-conda workflow.
+    """
+    if data_is_downloaded():
+        return None
+    UCSF_BOX_USER = os.environ.get("UCSF_BOX_USER")
+    UCSF_BOX_TOKEN = os.environ.get("UCSF_BOX_TOKEN")
+    if not all([UCSF_BOX_USER, UCSF_BOX_TOKEN]):
+        raise ValueError(
+            "Missing data, no credentials: UCSF_BOX_USER or UCSF_BOX_TOKEN."
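+            # To run offline, place TEST_FILE in RAW_DIR ahead of time;
+            # data_is_downloaded() above will then skip this download.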
+ ) + data_url = f"ftps://ftp.box.com/trodes_to_nwb_test_data/{TEST_FILE}" - # note that in this configuration, every test will use the same datajoint - # server this may create conflicts and dependencies between tests it may be - # better but significantly slower to start a new server for every test but - # the server needs to be started before tests are collected because - # datajoint runs when the source files are loaded, not when the tests are - # run. one solution might be to restart the server after every test + cmd = [ + "wget", + "--recursive", + "--no-host-directories", + "--no-directories", + "--user", + UCSF_BOX_USER, + "--password", + UCSF_BOX_TOKEN, + "-P", + RAW_DIR, + data_url, + ] + if not verbose: + cmd.insert(cmd.index("--recursive") + 1, "--no-verbose") + cmd_kwargs = dict(stdout=sys.stdout, stderr=sys.stderr) if verbose else {} - global __PROCESS - __PROCESS = run_datajoint_server() + return Popen(cmd, **cmd_kwargs) def pytest_unconfigure(config): - if __PROCESS: - print("Terminating datajoint compute resource process") - __PROCESS.terminate() - # TODO handle ResourceWarning: subprocess X is still running - # __PROCESS.join() + if TEARDOWN: + SERVER.stop() + + +# ------------------- FIXTURES ------------------- + + +@pytest.fixture(scope="session") +def verbose(): + """Config for pytest fixtures.""" + yield VERBOSE + + +@pytest.fixture(scope="session", autouse=True) +def verbose_context(verbose): + """Verbosity context for suppressing Spyglass logging.""" + yield nullcontext() if verbose else QuietStdOut() + + +@pytest.fixture(scope="session") +def teardown(request): + yield TEARDOWN + + +@pytest.fixture(scope="session") +def server(request, teardown): + SERVER.wait() + yield SERVER + if teardown: + SERVER.stop() + + +@pytest.fixture(scope="session") +def dj_conn(request, server, verbose, teardown): + """Fixture for datajoint connection.""" + config_file = "dj_local_conf.json_pytest" + + dj.config.update(server.creds) + dj.config["loglevel"] = "INFO" if verbose else "ERROR" + dj.config.save(config_file) + dj.conn() + yield dj.conn() + if teardown: + if Path(config_file).exists(): + os.remove(config_file) + + +@pytest.fixture(scope="session") +def base_dir(): + yield BASE_DIR + + +@pytest.fixture(scope="session") +def raw_dir(base_dir): + # could do settings.raw_dir, but this is faster while server booting + yield base_dir / "raw" + + +@pytest.fixture(scope="session") +def mini_path(raw_dir): + path = raw_dir / TEST_FILE + + # wait for wget download to finish + if DOWNLOAD is not None: + DOWNLOAD.wait() + + # wait for gh-actions download to finish + timeout, wait, found = 60, 5, False + for _ in range(timeout // wait): + if path.exists(): + found = True + break + tsleep(wait) + + if not found: + raise ConnectionError("Download failed.") + + yield path + + +@pytest.fixture(scope="session") +def mini_copy_name(mini_path): + from spyglass.utils.nwb_helper_fn import get_nwb_copy_filename # noqa: E402 + + yield get_nwb_copy_filename(mini_path).split("/")[-1] + + +@pytest.fixture(scope="session") +def mini_content(mini_path): + with pynwb.NWBHDF5IO( + path=str(mini_path), mode="r", load_namespaces=True + ) as io: + nwbfile = io.read() + assert nwbfile is not None, "NWBFile empty." 
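+        # Yield while the NWBHDF5IO context is still open so session-scoped
+        # consumers can lazily read datasets from the open file.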
+        yield nwbfile
+
+
+@pytest.fixture(scope="session")
+def mini_open(mini_content):
+    yield mini_content
+
+
+@pytest.fixture(scope="session")
+def mini_closed(mini_path):
+    with pynwb.NWBHDF5IO(
+        path=str(mini_path), mode="r", load_namespaces=True
+    ) as io:
+        nwbfile = io.read()
+    yield nwbfile
+
+
+@pytest.fixture(autouse=True, scope="session")
+def mini_insert(mini_path, teardown, server, dj_conn):
+    from spyglass.common import Nwbfile, Session  # noqa: E402
+    from spyglass.data_import import insert_sessions  # noqa: E402
+    from spyglass.utils.nwb_helper_fn import close_nwb_files  # noqa: E402
+
+    dj_logger.info("Inserting test data.")
+
+    if not server.connected:
+        raise ConnectionError("No server connection.")
+
+    if len(Nwbfile()) != 0:
+        dj_logger.warning("Skipping insert, use existing data.")
+    else:
+        insert_sessions(mini_path.name)
+
+    if len(Session()) == 0:
+        raise ValueError("No sessions inserted.")
+
+    yield
+
+    close_nwb_files()
+    # Note: no need to run deletes in teardown; on teardown, removing the
+    # container also removes the database.
+
+
+@pytest.fixture(scope="session")
+def mini_restr(mini_path):
+    yield f"nwb_file_name LIKE '{mini_path.stem}%'"
+
+
+@pytest.fixture(scope="session")
+def mini_dict(mini_copy_name):
+    yield {"nwb_file_name": mini_copy_name}
+
+
+@pytest.fixture(scope="session")
+def common(dj_conn):
+    from spyglass import common
+
+    yield common
+
+
+@pytest.fixture(scope="session")
+def data_import(dj_conn):
+    from spyglass import data_import
+
+    yield data_import
+
+
+@pytest.fixture(scope="session")
+def settings(dj_conn):
+    from spyglass import settings
+
+    yield settings
+
+
+@pytest.fixture(scope="session")
+def populate_exception():
+    from spyglass.common.errors import PopulateException
+
+    yield PopulateException
+
+
+# ------------------ GENERAL FUNCTION ------------------
+
 
-    kill_datajoint_server()
-    shutil.rmtree(os.environ["SPYGLASS_BASE_DIR"])
+class QuietStdOut:
+    """Suppress Spyglass logging and stdout when --quiet-spy is set."""
 
+    def __init__(self):
+        from spyglass.utils import logger as spyglass_logger
 
-def _set_env():
-    """Set environment variables."""
-    print("Setting datajoint and kachery environment variables.")
+        self.spy_logger = spyglass_logger
+        self.previous_level = None
 
-    os.environ["SPYGLASS_BASE_DIR"] = str(tempfile.mkdtemp())
+    def __enter__(self):
+        self.previous_level = self.spy_logger.getEffectiveLevel()
+        self.spy_logger.setLevel("CRITICAL")
+        self._original_stdout = sys.stdout
+        sys.stdout = open(os.devnull, "w")
 
-    dj.config["database.host"] = "localhost"
-    dj.config["database.port"] = DATAJOINT_SERVER_PORT
-    dj.config["database.user"] = "root"
-    dj.config["database.password"] = "tutorial"
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.spy_logger.setLevel(self.previous_level)
+        sys.stdout.close()
+        sys.stdout = self._original_stdout
diff --git a/tests/container.py b/tests/container.py
new file mode 100644
index 000000000..df820f1d0
--- /dev/null
+++ b/tests/container.py
@@ -0,0 +1,216 @@
+import atexit
+import time
+
+import datajoint as dj
+import docker
+from datajoint import logger
+
+
+class DockerMySQLManager:
+    """Manage Docker container for MySQL server
+
+    Parameters
+    ----------
+    image_name : str
+        Docker image name. Default 'datajoint/mysql'.
+    mysql_version : str
+        MySQL version. Default '8.0'.
+    container_name : str
+        Docker container name. Default 'spyglass-pytest'.
+    port : str
+        Port to map to DJ's default 3306. Default '330[mysql_version]'
+        (i.e., 3308 if testing 8.0). 
+ null_server : bool + If True, do not start container. Return on all methods. Default False. + Useful for iterating on tests in existing container. + restart : bool + If True, stop and remove existing container on startup. Default True. + shutdown : bool + If True, stop and remove container on exit from python. Default True. + verbose : bool + If True, print container status on startup. Default False. + """ + + def __init__( + self, + image_name="datajoint/mysql", + mysql_version="8.0", + container_name="spyglass-pytest", + port=None, + null_server=False, + restart=True, + shutdown=True, + verbose=False, + ) -> None: + self.image_name = image_name + self.mysql_version = mysql_version + self.container_name = container_name + self.port = port or "330" + self.mysql_version[0] + self.client = docker.from_env() + self.null_server = null_server + self.password = "tutorial" + self.user = "root" + self.host = "localhost" + self._ran_container = None + self.logger = logger + self.logger.setLevel("INFO" if verbose else "ERROR") + + if not self.null_server: + if shutdown: + atexit.register(self.stop) # stop container on python exit + if restart: + self.stop() # stop container if it exists + self.start() + + @property + def container(self) -> docker.models.containers.Container: + return self.client.containers.get(self.container_name) + + @property + def container_status(self) -> str: + try: + self.container.reload() + return self.container.status + except docker.errors.NotFound: + return None + + @property + def container_health(self) -> str: + try: + self.container.reload() + return self.container.health + except docker.errors.NotFound: + return None + + @property + def msg(self) -> str: + return f"Container {self.container_name} " + + def start(self) -> str: + if self.null_server: + return None + + elif self.container_status in ["created", "running", "restarting"]: + self.logger.info( + self.msg + "starting: " + self.container_status + "." + ) + + elif self.container_status == "exited": + self.logger.info(self.msg + "restarting.") + self.container.restart() + + else: + self._ran_container = self.client.containers.run( + image=f"{self.image_name}:{self.mysql_version}", + name=self.container_name, + ports={3306: self.port}, + environment=[ + f"MYSQL_ROOT_PASSWORD={self.password}", + "MYSQL_DEFAULT_STORAGE_ENGINE=InnoDB", + ], + detach=True, + tty=True, + ) + self.logger.info(self.msg + "starting new.") + + return self.container.name + + def wait(self, timeout=120, wait=5) -> None: + """Wait for healthy container. + + Parameters + ---------- + timeout : int + Timeout in seconds. Default 120. + wait : int + Time to wait between checks in seconds. Default 5. + """ + + if self.null_server: + return None + if not self.container_status or self.container_status == "exited": + self.start() + + for i in range(timeout // wait): + if self.container.health == "healthy": + break + self.logger.info(f"Container {self.container_name} starting... {i}") + time.sleep(wait) + self.logger.info( + f"Container {self.container_name}, {self.container.health}." 
+        )
+
+    @property
+    def _add_sql(self) -> str:
+        ESC = r"\_%"
+        return (
+            "CREATE USER IF NOT EXISTS 'basic'@'%' IDENTIFIED BY "
+            + f"'{self.password}'; GRANT USAGE ON `%`.* TO 'basic'@'%';"
+            + "GRANT SELECT ON `%`.* TO 'basic'@'%';"
+            + f"GRANT ALL PRIVILEGES ON `common{ESC}`.* TO `basic`@`%`;"
+            + f"GRANT ALL PRIVILEGES ON `spikesorting{ESC}`.* TO `basic`@`%`;"
+            + f"GRANT ALL PRIVILEGES ON `lfp{ESC}`.* TO `basic`@`%`;"
+            + f"GRANT ALL PRIVILEGES ON `position{ESC}`.* TO `basic`@`%`;"
+            + f"GRANT ALL PRIVILEGES ON `ripple{ESC}`.* TO `basic`@`%`;"
+            + f"GRANT ALL PRIVILEGES ON `linearization{ESC}`.* TO `basic`@`%`;"
+        ).strip()
+
+    def add_user(self) -> int:
+        """Add 'basic' user to container."""
+        if self.null_server:
+            return None
+
+        if self.container_status == "running":
+            result = self.container.exec_run(
+                cmd=[
+                    "mysql",
+                    "-u",
+                    self.user,
+                    f"--password={self.password}",
+                    "-e",
+                    self._add_sql,
+                ],
+                stdout=False,
+                stderr=False,
+                tty=True,
+            )
+            if result.exit_code == 0:
+                self.logger.info("Container added user.")
+            else:
+                logger.error("Failed to add user.")
+            return result.exit_code
+        else:
+            logger.error(f"Container {self.container_name} does not exist.")
+            return None
+
+    @property
+    def creds(self):
+        """Datajoint credentials for this container."""
+        return {
+            "database.host": "localhost",
+            "database.password": self.password,
+            "database.user": self.user,
+            "database.port": int(self.port),
+            "safemode": "false",
+            "custom": {"test_mode": True},
+        }
+
+    @property
+    def connected(self) -> bool:
+        self.wait()
+        dj.config.update(self.creds)
+        return dj.conn().is_connected
+
+    def stop(self, remove=True) -> None:
+        """Stop and remove container."""
+        if self.null_server:
+            return None
+        if not self.container_status or self.container_status == "exited":
+            return
+
+        self.container.stop()
+        self.logger.info(f"Container {self.container_name} stopped.")
+
+        if remove:
+            self.container.remove()
+            self.logger.info(f"Container {self.container_name} removed.")
diff --git a/tests/data_import/__init__.py b/tests/data_import/__init__.py
index e69de29bb..8f7eaee37 100644
--- a/tests/data_import/__init__.py
+++ b/tests/data_import/__init__.py
@@ -0,0 +1,3 @@
+# NOTE: test_insert_sessions does not increase coverage over common/test_insert
+# but it does declare its own nwbfile without downloading and tests broken
+# links, which aren't technically part of spyglass
diff --git a/tests/data_import/test_insert_sessions.py b/tests/data_import/test_insert_sessions.py
index d7968d164..7c125ed6b 100644
--- a/tests/data_import/test_insert_sessions.py
+++ b/tests/data_import/test_insert_sessions.py
@@ -1,104 +1,39 @@
-import datetime
-import os
-import pathlib
 import shutil
+import warnings
+from pathlib import Path
 
-import datajoint as dj
 import pynwb
 import pytest
 from hdmf.backends.warnings import BrokenLinkWarning
 
-from spyglass.data_import.insert_sessions import copy_nwb_link_raw_ephys
-from spyglass.settings import raw_dir
 
-@pytest.fixture()
-def new_nwbfile_raw_file_name(tmp_path):
-    nwbfile = pynwb.NWBFile(
-        session_description="session_description",
-        identifier="identifier",
-        session_start_time=datetime.datetime.now(datetime.timezone.utc),
-    )
-
-    device = nwbfile.create_device("dev1")
-    group = nwbfile.create_electrode_group(
-        "tetrode1", "tetrode description", "tetrode location", device
-    )
-    nwbfile.add_electrode(
-        id=1,
-        x=1.0,
-        y=2.0,
-        z=3.0,
-        imp=-1.0,
-        location="CA1",
-        filtering="none",
-        group=group,
-        group_name="tetrode1",
+@pytest.fixture(scope="session")
+def 
copy_nwb_link_raw_ephys(data_import): + from spyglass.data_import.insert_sessions import ( # noqa: E402 + copy_nwb_link_raw_ephys, ) - region = nwbfile.create_electrode_table_region( - region=[0], description="electrode 1" - ) - - es = pynwb.ecephys.ElectricalSeries( - name="test_ts", - data=[1, 2, 3], - timestamps=[1.0, 2.0, 3.0], - electrodes=region, - ) - nwbfile.add_acquisition(es) - - _ = tmp_path # CBroz: Changed to match testing base directory - file_name = "raw.nwb" - file_path = raw_dir + "/" + file_name + return copy_nwb_link_raw_ephys - with pynwb.NWBHDF5IO(str(file_path), mode="w") as io: - io.write(nwbfile) - return file_name +def test_open_path(mini_path, mini_open): + this_acq = mini_open.acquisition + assert "e-series" in this_acq, "Ephys link no longer exists" + assert ( + str(mini_path) == this_acq["e-series"].data.file.filename + ), "Path of ephys link is incorrect" -@pytest.fixture() -def new_nwbfile_no_ephys_file_name(): - return "raw_no_ephys.nwb" - - -@pytest.fixture() -def moved_nwbfile_no_ephys_file_path(tmp_path, new_nwbfile_no_ephys_file_name): - return tmp_path / new_nwbfile_no_ephys_file_name - - -def test_copy_nwb( - new_nwbfile_raw_file_name, - new_nwbfile_no_ephys_file_name, - moved_nwbfile_no_ephys_file_path, -): - copy_nwb_link_raw_ephys( - new_nwbfile_raw_file_name, new_nwbfile_no_ephys_file_name - ) - - # new file should not have ephys data - base_dir = pathlib.Path(os.getenv("SPYGLASS_BASE_DIR", None)) - new_nwbfile_raw_file_name_abspath = ( - base_dir / "raw" / new_nwbfile_raw_file_name - ) - out_nwb_file_abspath = base_dir / "raw" / new_nwbfile_no_ephys_file_name - with pynwb.NWBHDF5IO(path=str(out_nwb_file_abspath), mode="r") as io: - nwbfile = io.read() - assert ( - "test_ts" in nwbfile.acquisition - ) # this still exists but should be a link now - assert nwbfile.acquisition["test_ts"].data.file.filename == str( - new_nwbfile_raw_file_name_abspath - ) - # test readability after moving the linking raw file (paths are stored as - # relative paths in NWB) so this should break the link (moving the linked-to - # file should also break the link) +def test_copy_link(mini_path, settings, mini_closed, copy_nwb_link_raw_ephys): + """Test readability after moving the linking raw file, breaking link""" + new_path = Path(settings.raw_dir) / "no_ephys.nwb" + new_moved = Path(settings.temp_dir) / "no_ephys_moved.nwb" - shutil.move(out_nwb_file_abspath, moved_nwbfile_no_ephys_file_path) - with pynwb.NWBHDF5IO( - path=str(moved_nwbfile_no_ephys_file_path), mode="r" - ) as io: - with pytest.warns(BrokenLinkWarning): - nwbfile = io.read() # should raise BrokenLinkWarning - assert "test_ts" not in nwbfile.acquisition + copy_nwb_link_raw_ephys(mini_path.name, new_path.name) + shutil.move(new_path, new_moved) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=UserWarning) + with pynwb.NWBHDF5IO(path=str(new_moved), mode="r") as io: + with pytest.warns(BrokenLinkWarning): + nwb_acq = io.read().acquisition + assert "e-series" not in nwb_acq, "Ephys link still exists after move" diff --git a/tests/datajoint/_config.py b/tests/datajoint/_config.py deleted file mode 100644 index 3798427ea..000000000 --- a/tests/datajoint/_config.py +++ /dev/null @@ -1 +0,0 @@ -DATAJOINT_SERVER_PORT = 3307 diff --git a/tests/datajoint/_datajoint_server.py b/tests/datajoint/_datajoint_server.py deleted file mode 100644 index f12455e67..000000000 --- a/tests/datajoint/_datajoint_server.py +++ /dev/null @@ -1,110 +0,0 @@ -import multiprocessing -import os -import time 
-import traceback - -import kachery_client as kc -from pymysql.err import OperationalError - -from ._config import DATAJOINT_SERVER_PORT - -DOCKER_IMAGE_NAME = "datajoint-server-pytest" - - -def run_service_datajoint_server(): - # The following cleanup is needed because we terminate this compute resource process - # See: https://pytest-cov.readthedocs.io/en/latest/subprocess-support.html - from pytest_cov.embed import cleanup_on_sigterm - - cleanup_on_sigterm() - - os.environ["RUNNING_PYTEST"] = "TRUE" - - ss = kc.ShellScript( - f""" - #!/bin/bash - set -ex - - docker kill {DOCKER_IMAGE_NAME} > /dev/null 2>&1 || true - docker rm {DOCKER_IMAGE_NAME} > /dev/null 2>&1 || true - exec docker run --name {DOCKER_IMAGE_NAME} -e MYSQL_ROOT_PASSWORD=tutorial -p {DATAJOINT_SERVER_PORT}:3306 datajoint/mysql - """, - redirect_output_to_stdout=True, - ) # noqa: E501 - ss.start() - ss.wait() - - -def run_datajoint_server(): - print("Starting datajoint server") - - ss_pull = kc.ShellScript( - """ - #!/bin/bash - set -ex - - exec docker pull datajoint/mysql - """ - ) - ss_pull.start() - ss_pull.wait() - - process = multiprocessing.Process( - target=run_service_datajoint_server, kwargs=dict() - ) - process.start() - - try: - _wait_for_datajoint_server_to_start() - except Exception: - kill_datajoint_server() - raise - - return process - # yield process - - # process.terminate() - # kill_datajoint_server() - - -def kill_datajoint_server(): - print("Terminating datajoint server") - - ss2 = kc.ShellScript( - f""" - #!/bin/bash - - set -ex - - docker kill {DOCKER_IMAGE_NAME} || true - docker rm {DOCKER_IMAGE_NAME} - """ - ) - ss2.start() - ss2.wait() - - -def _wait_for_datajoint_server_to_start(): - time.sleep(15) # it takes a while to start the server - timer = time.time() - print("Waiting for DataJoint server to start. Time", timer) - while True: - try: - from spyglass.common import Session # noqa: F401 - - return - except OperationalError as e: # e.g. Connection Error - print("DataJoint server not yet started. Time", time.time()) - print(e) - except Exception: - print("Failed to import Session. Time", time.time()) - print(traceback.format_exc()) - current_time = time.time() - elapsed = current_time - timer - if elapsed > 300: - raise Exception( - "Timeout while waiting for datajoint server to start and " - "import Session to succeed. 
Time", - current_time, - ) - time.sleep(5) diff --git a/tests/lfp/conftest.py b/tests/lfp/conftest.py new file mode 100644 index 000000000..2eb511265 --- /dev/null +++ b/tests/lfp/conftest.py @@ -0,0 +1,215 @@ +import numpy as np +import pytest +from pynwb import NWBHDF5IO + + +@pytest.fixture(scope="session") +def lfp(common): + from spyglass import lfp + + return lfp + + +@pytest.fixture(scope="session") +def lfp_band(lfp): + from spyglass.lfp.analysis.v1 import lfp_band + + return lfp_band + + +@pytest.fixture(scope="session") +def firfilters_table(common): + return common.FirFilterParameters() + + +@pytest.fixture(scope="session") +def electrodegroup_table(lfp): + return lfp.v1.LFPElectrodeGroup() + + +@pytest.fixture(scope="session") +def lfp_constants(common, mini_copy_name, mini_dict): + n_delay = 9 + lfp_electrode_group_name = "test" + orig_list_name = "01_s1" + orig_valid_times = ( + common.IntervalList + & mini_dict + & f"interval_list_name = '{orig_list_name}'" + ).fetch1("valid_times") + new_list_name = orig_list_name + f"_first{n_delay}" + new_list_key = { + "nwb_file_name": mini_copy_name, + "interval_list_name": new_list_name, + "valid_times": np.asarray( + [[orig_valid_times[0, 0], orig_valid_times[0, 0] + n_delay]] + ), + } + + yield dict( + lfp_electrode_ids=[0], + lfp_electrode_group_name=lfp_electrode_group_name, + lfp_eg_key={ + "nwb_file_name": mini_copy_name, + "lfp_electrode_group_name": lfp_electrode_group_name, + }, + n_delay=n_delay, + orig_interval_list_name=orig_list_name, + orig_valid_times=orig_valid_times, + interval_list_name=new_list_name, + interval_key=new_list_key, + filter1_name="LFP 0-400 Hz", + filter_sampling_rate=30_000, + filter2_name="Theta 5-11 Hz", + lfp_band_electrode_ids=[0], # assumes we've filtered these electrodes + lfp_band_sampling_rate=100, # desired sampling rate + ) + + +@pytest.fixture(scope="session") +def add_electrode_group( + firfilters_table, + electrodegroup_table, + mini_copy_name, + lfp_constants, +): + firfilters_table.create_standard_filters() + group_name = lfp_constants.get("lfp_electrode_group_name") + electrodegroup_table.create_lfp_electrode_group( + nwb_file_name=mini_copy_name, + group_name=group_name, + electrode_list=lfp_constants.get("lfp_electrode_ids"), + ) + assert len( + electrodegroup_table & {"lfp_electrode_group_name": group_name} + ), "Failed to add LFPElectrodeGroup." 
+ yield + + +@pytest.fixture(scope="session") +def add_interval(common, lfp_constants): + common.IntervalList.insert1( + lfp_constants.get("interval_key"), skip_duplicates=True + ) + yield lfp_constants.get("interval_list_name") + + +@pytest.fixture(scope="session") +def add_selection( + lfp, common, add_electrode_group, add_interval, lfp_constants +): + lfp_s_key = { + **lfp_constants.get("lfp_eg_key"), + "target_interval_list_name": add_interval, + "filter_name": lfp_constants.get("filter1_name"), + "filter_sampling_rate": lfp_constants.get("filter_sampling_rate"), + } + lfp.v1.LFPSelection.insert1(lfp_s_key, skip_duplicates=True) + yield lfp_s_key + + +@pytest.fixture(scope="session") +def lfp_s_key(lfp_constants, mini_copy_name): + yield { + "nwb_file_name": mini_copy_name, + "lfp_electrode_group_name": lfp_constants.get( + "lfp_electrode_group_name" + ), + "target_interval_list_name": lfp_constants.get("interval_list_name"), + } + + +@pytest.fixture(scope="session") +def populate_lfp(lfp, add_selection, lfp_s_key): + lfp.v1.LFPV1().populate(add_selection) + yield {"merge_id": (lfp.LFPOutput.LFPV1() & lfp_s_key).fetch1("merge_id")} + + +@pytest.fixture(scope="session") +def lfp_merge_key(populate_lfp): + yield populate_lfp + + +@pytest.fixture(scope="module") +def lfp_analysis_raw(common, lfp, populate_lfp, mini_dict): + abs_path = (common.AnalysisNwbfile * lfp.v1.LFPV1 & mini_dict).fetch( + "analysis_file_abs_path" + )[0] + assert abs_path is not None, "No NWBFile found." + with NWBHDF5IO(path=str(abs_path), mode="r", load_namespaces=True) as io: + nwbfile = io.read() + assert nwbfile is not None, "NWBFile empty." + yield nwbfile + + +@pytest.fixture(scope="session") +def lfp_band_sampling_rate(lfp, lfp_merge_key): + yield lfp.LFPOutput.merge_get_parent(lfp_merge_key).fetch1( + "lfp_sampling_rate" + ) + + +@pytest.fixture(scope="session") +def add_band_filter(common, lfp_constants, lfp_band_sampling_rate): + filter_name = lfp_constants.get("filter2_name") + common.FirFilterParameters().add_filter( + filter_name, + lfp_band_sampling_rate, + "bandpass", + [4, 5, 11, 12], + "theta filter for 1 Khz data", + ) + yield lfp_constants.get("filter2_name") + + +@pytest.fixture(scope="session") +def add_band_selection( + lfp_band, + mini_copy_name, + mini_dict, + lfp_merge_key, + add_interval, + lfp_constants, + add_band_filter, + add_electrode_group, +): + lfp_band.LFPBandSelection().set_lfp_band_electrodes( + nwb_file_name=mini_copy_name, + lfp_merge_id=lfp_merge_key.get("merge_id"), + electrode_list=lfp_constants.get("lfp_band_electrode_ids"), + filter_name=add_band_filter, + interval_list_name=add_interval, + reference_electrode_list=[-1], + lfp_band_sampling_rate=lfp_constants.get("lfp_band_sampling_rate"), + ) + yield (lfp_band.LFPBandSelection & mini_dict).fetch1("KEY") + + +@pytest.fixture(scope="session") +def lfp_band_key(add_band_selection): + yield add_band_selection + + +@pytest.fixture(scope="session") +def populate_lfp_band(lfp_band, add_band_selection): + lfp_band.LFPBandV1().populate(add_band_selection) + yield + + +# @pytest.fixture(scope="session") +# def mini_eseries(common, mini_copy_name): +# yield (common.Raw() & {"nwb_file_name": mini_copy_name}).fetch_nwb()[0][ +# "raw" +# ] + + +@pytest.fixture(scope="module") +def lfp_band_analysis_raw(common, lfp_band, populate_lfp_band, mini_dict): + abs_path = (common.AnalysisNwbfile * lfp_band.LFPBandV1 & mini_dict).fetch( + "analysis_file_abs_path" + )[0] + assert abs_path is not None, "No NWBFile found." 
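+    # Read the analysis NWB file directly so tests can compare its stored
+    # arrays against what the table's fetch1_dataframe() returns.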
+    with NWBHDF5IO(path=str(abs_path), mode="r", load_namespaces=True) as io:
+        nwbfile = io.read()
+        assert nwbfile is not None, "NWBFile empty."
+        yield nwbfile
diff --git a/tests/lfp/test_pipeline.py b/tests/lfp/test_pipeline.py
new file mode 100644
index 000000000..86599190d
--- /dev/null
+++ b/tests/lfp/test_pipeline.py
@@ -0,0 +1,25 @@
+from pandas import DataFrame, Index
+
+
+def test_lfp_dataframe(common, lfp, lfp_analysis_raw, lfp_merge_key):
+    lfp_raw = lfp_analysis_raw.scratch["filtered data"]
+    df_raw = DataFrame(
+        lfp_raw.data, index=Index(lfp_raw.timestamps, name="time")
+    )
+    df_fetch = (lfp.LFPOutput & lfp_merge_key).fetch1_dataframe()
+
+    assert df_raw.equals(df_fetch), "LFP dataframe does not match."
+
+
+def test_lfp_band_dataframe(lfp_band_analysis_raw, lfp_band, lfp_band_key):
+    lfp_band_raw = (
+        lfp_band_analysis_raw.processing["ecephys"]
+        .fields["data_interfaces"]["LFP"]
+        .electrical_series["filtered data"]
+    )
+    df_raw = DataFrame(
+        lfp_band_raw.data, index=Index(lfp_band_raw.timestamps, name="time")
+    )
+    df_fetch = (lfp_band.LFPBandV1 & lfp_band_key).fetch1_dataframe()
+
+    assert df_raw.equals(df_fetch), "LFPBand dataframe does not match."
diff --git a/tests/test_insert_beans.py b/tests/test_insert_beans.py
deleted file mode 100644
index d74ecb856..000000000
--- a/tests/test_insert_beans.py
+++ /dev/null
@@ -1,97 +0,0 @@
-from datetime import datetime
-import kachery_cloud as kcl
-import os
-import pathlib
-import pynwb
-import pytest
-
-
-@pytest.mark.skip(reason="test_path needs to be updated")
-def test_insert_sessions():
-    print(
-        "In test_insert_sessions, os.environ['SPYGLASS_BASE_DIR'] is",
-        os.environ["SPYGLASS_BASE_DIR"],
-    )
-    raw_dir = pathlib.Path(os.environ["SPYGLASS_BASE_DIR"]) / "raw"
-    nwbfile_path = raw_dir / "test.nwb"
-
-    from spyglass.common import (
-        Session,
-        DataAcquisitionDevice,
-        CameraDevice,
-        Probe,
-    )
-    from spyglass.data_import import insert_sessions
-
-    test_path = (
-        "ipfs://bafybeie4svt3paz5vr7cw7mkgibutbtbzyab4s24hqn5pzim3sgg56m3n4"
-    )
-    try:
-        local_test_path = kcl.load_file(test_path)
-    except Exception as e:
-        if os.environ.get("KACHERY_CLOUD_EPHEMERAL", None) != "TRUE":
-            print(
-                "Cannot load test file in non-ephemeral mode. Kachery cloud client may need to be registered."
-            )
-            raise e
-
-    # move the file to spyglass raw dir
-    os.rename(local_test_path, nwbfile_path)
-
-    # test that the file can be read. 
this is not used otherwise - with pynwb.NWBHDF5IO( - path=str(nwbfile_path), mode="r", load_namespaces=True - ) as io: - nwbfile = io.read() - assert nwbfile is not None - - insert_sessions(nwbfile_path.name) - - x = (Session() & {"nwb_file_name": "test_.nwb"}).fetch1() - assert x["nwb_file_name"] == "test_.nwb" - assert x["subject_id"] == "Beans" - assert x["institution_name"] == "University of California, San Francisco" - assert x["lab_name"] == "Loren Frank" - assert x["session_id"] == "beans_01" - assert x["session_description"] == "Reinforcement leaarning" - assert x["session_start_time"] == datetime(2019, 7, 18, 15, 29, 47) - assert x["timestamps_reference_time"] == datetime(1970, 1, 1, 0, 0) - assert x["experiment_description"] == "Reinforcement learning" - - x = DataAcquisitionDevice().fetch() - assert len(x) == 1 - assert x[0]["device_name"] == "dataacq_device0" - assert x[0]["system"] == "SpikeGadgets" - assert x[0]["amplifier"] == "Intan" - assert x[0]["adc_circuit"] == "Intan" - - x = CameraDevice().fetch() - assert len(x) == 2 - # NOTE order of insertion is not consistent so cannot use x[0] - expected1 = dict( - camera_name="beans sleep camera", - # meters_per_pixel=0.00055, # cannot check floating point values this way - manufacturer="", - model="unknown", - lens="unknown", - camera_id=0, - ) - assert CameraDevice() & expected1 - assert (CameraDevice() & expected1).fetch("meters_per_pixel") == 0.00055 - expected2 = dict( - camera_name="beans run camera", - # meters_per_pixel=0.002, - manufacturer="", - model="unknown2", - lens="unknown2", - camera_id=1, - ) - assert CameraDevice() & expected2 - assert (CameraDevice() & expected2).fetch("meters_per_pixel") == 0.002 - - x = Probe().fetch() - assert len(x) == 1 - assert x[0]["probe_type"] == "128c-4s8mm6cm-20um-40um-sl" - assert x[0]["probe_description"] == "128 channel polyimide probe" - assert x[0]["num_shanks"] == 4 - assert x[0]["contact_side_numbering"] == "True" diff --git a/tests/trim_beans.py b/tests/trim_beans.py deleted file mode 100644 index 242e65c49..000000000 --- a/tests/trim_beans.py +++ /dev/null @@ -1,73 +0,0 @@ -import pynwb - -# import ndx_franklab_novela - -file_in = "beans20190718.nwb" -file_out = "beans20190718_trimmed.nwb" - -n_timestamps_to_keep = 20 # / 20000 Hz sampling rate = 1 ms - -with pynwb.NWBHDF5IO(file_in, "r", load_namespaces=True) as io: - nwbfile = io.read() - orig_eseries = nwbfile.acquisition.pop("e-series") - - # create a new ElectricalSeries with a subset of the data and timestamps - data = orig_eseries.data[0:n_timestamps_to_keep, :] - ts = orig_eseries.timestamps[0:n_timestamps_to_keep] - electrodes = nwbfile.create_electrode_table_region( - region=orig_eseries.electrodes.data[:].tolist(), - name=orig_eseries.electrodes.name, - description=orig_eseries.electrodes.description, - ) - new_eseries = pynwb.ecephys.ElectricalSeries( - name=orig_eseries.name, - description=orig_eseries.description, - data=data, - timestamps=ts, - electrodes=electrodes, - ) - nwbfile.add_acquisition(new_eseries) - - # create a new analog TimeSeries with a subset of the data and timestamps - orig_analog = nwbfile.processing["analog"]["analog"].time_series.pop( - "analog" - ) - data = orig_analog.data[0:n_timestamps_to_keep, :] - ts = orig_analog.timestamps[0:n_timestamps_to_keep] - new_analog = pynwb.TimeSeries( - name=orig_analog.name, - description=orig_analog.description, - data=data, - timestamps=ts, - unit=orig_analog.unit, - ) - nwbfile.processing["analog"]["analog"].add_timeseries(new_analog) - - # remove 
last two columns of all SpatialSeries data (xloc2, yloc2) because - # it does not conform with NWB 2.5 and they are all zeroes anyway - new_spatial_series = list() - for spatial_series_name in list( - nwbfile.processing["behavior"]["position"].spatial_series - ): - spatial_series = nwbfile.processing["behavior"][ - "position" - ].spatial_series.pop(spatial_series_name) - assert isinstance(spatial_series, pynwb.behavior.SpatialSeries) - data = spatial_series.data[:, 0:2] - ts = spatial_series.timestamps[0:n_timestamps_to_keep] - new_spatial_series.append( - pynwb.behavior.SpatialSeries( - name=spatial_series.name, - description=spatial_series.description, - data=data, - timestamps=spatial_series.timestamps, - reference_frame=spatial_series.reference_frame, - ) - ) - for spatial_series in new_spatial_series: - nwbfile.processing["behavior"]["position"].add_spatial_series( - spatial_series - ) - - with pynwb.NWBHDF5IO(file_out, "w") as export_io: - export_io.export(io, nwbfile) diff --git a/tests/test_nwb_helper_fn.py b/tests/utils/test_nwb_helper_fn.py similarity index 86% rename from tests/test_nwb_helper_fn.py rename to tests/utils/test_nwb_helper_fn.py index ad382b0a4..d054f7ecb 100644 --- a/tests/test_nwb_helper_fn.py +++ b/tests/utils/test_nwb_helper_fn.py @@ -3,9 +3,11 @@ import pynwb -# NOTE: importing this calls spyglass.__init__ whichand spyglass.common.__init__ which both require the -# DataJoint MySQL server to be already set up and running -from spyglass.common import get_electrode_indices + +def get_electrode_indices(*args, **kwargs): + from spyglass.common import get_electrode_indices # noqa: E402 + + return get_electrode_indices(*args, **kwargs) class TestGetElectrodeIndices(unittest.TestCase): @@ -48,7 +50,7 @@ def setUp(self): ) self.nwbfile.add_acquisition(eseries) - def test_nwbfile(self): + def test_electrode_nwbfile(self): ret = get_electrode_indices(self.nwbfile, [102, 105]) assert ret == [2, 5]