Skip to content

Commit

Permalink
Pytest revamp (LorenFrankLab#743)
Browse files Browse the repository at this point in the history
* WIP: Pull from old stash, resolve conflicts

* Pytest WIP. Position centriod fix. Centralize device prompt logic

* Add tests for all tables in

* WIP: Improve coverage behav, dio

* WIP: Add coverage, see details:

- Add `return_fig` param to plotting helper functions to permit tests
  - `common_filter`
  - `common_interval`
- Add coverage for ~1/2 of `common`
  - `common_behav`
  - `common_device`
  - `common_ephys`
  - `common_filter`
  - `common_interval` - with helper funcs tested seperately
  - `common_lab`
  - `common_nwbfile` - partial

* WIP pytest common 2nd half, start lfp

* WIP lfp tests, ahead of fetch upstream

* Add lfp pipeline tests

* Run pre-commit checks

* Fix bug

* Unpin position_tools for CI

* Change download data dir

* Change download data dir 2

* Fix teardown. Coverage 67%

* Update changelog

* logger.warn -> logger.warning
  • Loading branch information
CBroz1 authored Jan 19, 2024
1 parent 6705ee0 commit 0089d5e
Show file tree
Hide file tree
Showing 46 changed files with 2,283 additions and 639 deletions.
21 changes: 11 additions & 10 deletions .github/workflows/test-conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,6 @@ jobs:
env:
OS: ${{ matrix.os }}
PYTHON: '3.8'
# SPYGLASS_BASE_DIR: ./data
# KACHERY_STORAGE_DIR: ./data/kachery-storage
# DJ_SUPPORT_FILEPATH_MANAGEMENT: True
# services:
# datajoint_test_server:
# image: datajoint/mysql
# ports:
# - 3306:3306
# options: >-
# -e MYSQL_ROOT_PASSWORD=tutorial
steps:
- name: Cancel Workflow Action
uses: styfle/[email protected]
Expand All @@ -49,6 +39,17 @@ jobs:
- name: Install spyglass
run: |
pip install -e .[test]
- name: Download data
env:
UCSF_BOX_TOKEN: ${{ secrets.UCSF_BOX_TOKEN }}
UCSF_BOX_USER: ${{ secrets.UCSF_BOX_USER }}
WEBSITE: ftps://ftp.box.com/trodes_to_nwb_test_data/minirec20230622.nwb
RAW_DIR: /home/runner/work/spyglass/spyglass/tests/_data/raw/
run: |
mkdir -p $RAW_DIR
wget --recursive --no-verbose --no-host-directories --no-directories \
--user $UCSF_BOX_USER --password $UCSF_BOX_TOKEN \
-P $RAW_DIR $WEBSITE
- name: Run tests
run: |
pytest -rP # env vars are set within certain tests
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
- Add `deprecation_factory` to facilitate table migration. #717
- Add Spyglass logger. #730
- IntervalList: Add secondary key `pipeline` #742
- Increase pytest coverage for `common`, `lfp`, and `utils`. #743

### Pipelines

Expand All @@ -31,7 +32,6 @@
- Allow multiple spike waveform features for clusterelss decoding #731
- Reorder notebooks #731


## [0.4.3] (November 7, 2023)

- Migrate `config` helper scripts to Spyglass codebase. #662
Expand Down
39 changes: 38 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ spyglass_cli = "spyglass.cli:cli"
[project.optional-dependencies]
position = ["ffmpeg", "numba>=0.54", "deeplabcut<2.3.0"]
test = [
"docker", # for tests in a container
"pytest", # unit testing
"pytest-cov", # code coverage
"kachery", # database access
Expand Down Expand Up @@ -109,5 +110,41 @@ line-length = 80

[tool.codespell]
skip = '.git,*.pdf,*.svg,*.ipynb,./docs/site/**,temp*'
# Nevers - name in Citation
ignore-words-list = 'nevers'
# Nevers - name in Citation

[tool.pytest.ini_options]
minversion = "7.0"
addopts = [
"-sv",
"-p no:warnings",
# "--no-teardown", # don't teardown the database after tests
# "--quiet-spy", # don't show logging from spyglass
"--show-capture=no",
"--pdbcls=IPython.terminal.debugger:TerminalPdb", # use ipython debugger
"--cov=spyglass",
"--cov-report=term-missing",
"--no-cov-on-fail",
]
testpaths = ["tests"]
log_level = "INFO"

[tool.coverage.run]
source = ["*/src/spyglass/*"]
omit = [ # which submodules have no tests
"*/__init__.py",
"*/_version.py",
"*/cli/*",
# "*/common/*",
"*/data_import/*",
"*/decoding/*",
"*/figurl_views/*",
# "*/lfp/*",
"*/linearization/*",
"*/lock/*",
"*/position/*",
"*/ripple/*",
"*/sharing/*",
"*/spikesorting/*",
# "*/utils/*",
]
158 changes: 70 additions & 88 deletions src/spyglass/common/common_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import ndx_franklab_novela

from spyglass.common.errors import PopulateException
from spyglass.utils.dj_mixin import SpyglassMixin
from spyglass.utils.logging import logger
from spyglass.settings import test_mode
from spyglass.utils import SpyglassMixin, logger
from spyglass.utils.nwb_helper_fn import get_nwb_file

schema = dj.schema("common_device")
Expand Down Expand Up @@ -154,25 +154,9 @@ def _add_device(cls, new_device_dict):
all_values = DataAcquisitionDevice.fetch(
"data_acquisition_device_name"
).tolist()
if name not in all_values:
# no entry with the same name exists, prompt user to add a new entry
logger.info(
f"\nData acquisition device '{name}' was not found in the "
f"database. The current values are: {all_values}. "
"Please ensure that the device you want to add does not already"
" exist in the database under a different name or spelling. "
"If you want to use an existing device in the database, "
"please change the corresponding Device object in the NWB file."
" Entering 'N' will raise an exception."
)
to_db = " to the database"
val = input(f"Add data acquisition device '{name}'{to_db}? (y/N)")
if val.lower() in ["y", "yes"]:
cls.insert1(new_device_dict, skip_duplicates=True)
return
raise PopulateException(
f"User chose not to add device '{name}'{to_db}."
)
if prompt_insert(name=name, all_values=all_values):
cls.insert1(new_device_dict, skip_duplicates=True)
return

# Check if values provided match the values stored in the database
db_dict = (
Expand Down Expand Up @@ -213,28 +197,11 @@ def _add_system(cls, system):
all_values = DataAcquisitionDeviceSystem.fetch(
"data_acquisition_device_system"
).tolist()
if system not in all_values:
logger.info(
f"\nData acquisition device system '{system}' was not found in"
f" the database. The current values are: {all_values}. "
"Please ensure that the system you want to add does not already"
" exist in the database under a different name or spelling. "
"If you want to use an existing system in the database, "
"please change the corresponding Device object in the NWB file."
" Entering 'N' will raise an exception."
)
val = input(
f"Do you want to add data acquisition device system '{system}'"
+ " to the database? (y/N)"
)
if val.lower() in ["y", "yes"]:
key = {"data_acquisition_device_system": system}
DataAcquisitionDeviceSystem.insert1(key, skip_duplicates=True)
else:
raise PopulateException(
"User chose not to add data acquisition device system "
+ f"'{system}' to the database."
)
if prompt_insert(
name=system, all_values=all_values, table_type="system"
):
key = {"data_acquisition_device_system": system}
DataAcquisitionDeviceSystem.insert1(key, skip_duplicates=True)
return system

@classmethod
Expand Down Expand Up @@ -264,30 +231,11 @@ def _add_amplifier(cls, amplifier):
all_values = DataAcquisitionDeviceAmplifier.fetch(
"data_acquisition_device_amplifier"
).tolist()
if amplifier not in all_values:
logger.info(
f"\nData acquisition device amplifier '{amplifier}' was not "
f"found in the database. The current values are: {all_values}. "
"Please ensure that the amplifier you want to add does not "
"already exist in the database under a different name or "
"spelling. If you want to use an existing name in the database,"
" please change the corresponding Device object in the NWB "
"file. Entering 'N' will raise an exception."
)
val = input(
"Do you want to add data acquisition device amplifier "
+ f"'{amplifier}' to the database? (y/N)"
)
if val.lower() in ["y", "yes"]:
key = {"data_acquisition_device_amplifier": amplifier}
DataAcquisitionDeviceAmplifier.insert1(
key, skip_duplicates=True
)
else:
raise PopulateException(
"User chose not to add data acquisition device amplifier "
+ f"'{amplifier}' to the database."
)
if prompt_insert(
name=amplifier, all_values=all_values, table_type="amplifier"
):
key = {"data_acquisition_device_amplifier": amplifier}
DataAcquisitionDeviceAmplifier.insert1(key, skip_duplicates=True)
return amplifier


Expand Down Expand Up @@ -576,27 +524,9 @@ def _add_probe_type(cls, new_probe_type_dict):
"""
probe_type = new_probe_type_dict["probe_type"]
all_values = ProbeType.fetch("probe_type").tolist()
if probe_type not in all_values:
logger.info(
f"\nProbe type '{probe_type}' was not found in the database. "
f"The current values are: {all_values}. "
"Please ensure that the probe type you want to add does not "
"already exist in the database under a different name or "
"spelling. If you want to use an existing name in the "
"database, please change the corresponding Probe object in the "
"NWB file. Entering 'N' will raise an exception."
)
val = input(
f"Do you want to add probe type '{probe_type}' to the database?"
+ " (y/N)"
)
if val.lower() in ["y", "yes"]:
ProbeType.insert1(new_probe_type_dict, skip_duplicates=True)
return
raise PopulateException(
f"User chose not to add probe type '{probe_type}' to the "
+ "database."
)
if prompt_insert(probe_type, all_values, table="probe type"):
ProbeType.insert1(new_probe_type_dict, skip_duplicates=True)
return

# else / entry exists: check whether the values provided match the
# values stored in the database
Expand Down Expand Up @@ -738,3 +668,55 @@ def create_from_nwbfile(
cls.Shank.insert1(shank, skip_duplicates=True)
for electrode in elect_dict.values():
cls.Electrode.insert1(electrode, skip_duplicates=True)


# ---------------------------- Helper functions ----------------------------


# Migrated down to reduce redundancy and centralize 'test_mode' check for pytest
def prompt_insert(
name: str,
all_values: list,
table: str = "Data Acquisition Device",
table_type: str = None,
) -> bool:
"""Prompt user to add an item to the database. Return True if yes.
Assume insert during test mode.
Parameters
----------
name : str
The name of the item to add.
all_values : list
List of all values in the database.
table : str, optional
The name of the table to add to, by default Data Acquisition Device
table_type : str, optional
The type of item to add, by default None. Data Acquisition Device X
"""
if name in all_values:
return False

if test_mode:
return True

if table_type:
table_type += " "

logger.info(
f"{table}{table_type} '{name}' was not found in the"
f"database. The current values are: {all_values}.\n"
"Please ensure that the device you want to add does not already"
"exist in the database under a different name or spelling. If you"
"want to use an existing device in the database, please change the"
"corresponding Device object in the NWB file.\nEntering 'N' will "
"raise an exception."
)
msg = f"Do you want to add {table}{table_type} '{name}' to the database?"
if dj.utils.user_choice(msg).lower() in ["y", "yes"]:
return True

raise PopulateException(
f"User chose not to add {table}{table_type} '{name}' to the database."
)
5 changes: 4 additions & 1 deletion src/spyglass/common/common_dio.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def make(self, key):
key["dio_object_id"] = event_series.object_id
self.insert1(key, skip_duplicates=True)

def plot_all_dio_events(self):
def plot_all_dio_events(self, return_fig=False):
"""Plot all DIO events in the session.
Examples
Expand Down Expand Up @@ -117,3 +117,6 @@ def plot_all_dio_events(self):
plt.suptitle(f"DIO events in {nwb_file_names[0]}")
else:
plt.suptitle(f"DIO events in {', '.join(nwb_file_names)}")

if return_fig:
return plt.gcf()
12 changes: 8 additions & 4 deletions src/spyglass/common/common_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,9 @@ def add_filter(
def _filter_restrict(self, filter_name, fs):
return (
self & {"filter_name": filter_name} & {"filter_sampling_rate": fs}
).fetch1(as_dict=True)
).fetch1()

def plot_magnitude(self, filter_name, fs):
def plot_magnitude(self, filter_name, fs, return_fig=False):
filter_dict = self._filter_restrict(filter_name, fs)
plt.figure()
w, h = signal.freqz(filter_dict["filter_coeff"], worN=65536)
Expand All @@ -178,11 +178,13 @@ def plot_magnitude(self, filter_name, fs):
plt.xlabel("Frequency (Hz)")
plt.ylabel("Magnitude")
plt.title("Frequency Response")
plt.xlim(0, np.max(filter_dict["filter_coeffand_edges"] * 2))
plt.xlim(0, np.max(filter_dict["filter_band_edges"] * 2))
plt.ylim(np.min(magnitude), -1 * np.min(magnitude) * 0.1)
plt.grid(True)
if return_fig:
return plt.gcf()

def plot_fir_filter(self, filter_name, fs):
def plot_fir_filter(self, filter_name, fs, return_fig=False):
filter_dict = self._filter_restrict(filter_name, fs)
plt.figure()
plt.clf()
Expand All @@ -191,6 +193,8 @@ def plot_fir_filter(self, filter_name, fs):
plt.ylabel("Magnitude")
plt.title("Filter Taps")
plt.grid(True)
if return_fig:
return plt.gcf()

def filter_delay(self, filter_name, fs):
return self.calc_filter_delay(
Expand Down
8 changes: 6 additions & 2 deletions src/spyglass/common/common_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def insert_from_nwbfile(cls, nwbf, *, nwb_file_name):

cls.insert1(epoch_dict, skip_duplicates=True)

def plot_intervals(self, figsize=(20, 5)):
def plot_intervals(self, figsize=(20, 5), return_fig=False):
interval_list = pd.DataFrame(self)
fig, ax = plt.subplots(figsize=figsize)
interval_count = 0
Expand All @@ -84,8 +84,10 @@ def plot_intervals(self, figsize=(20, 5)):
ax.set_yticklabels(interval_list.interval_list_name)
ax.set_xlabel("Time [s]")
ax.grid(True)
if return_fig:
return fig

def plot_epoch_pos_raw_intervals(self, figsize=(20, 5)):
def plot_epoch_pos_raw_intervals(self, figsize=(20, 5), return_fig=False):
interval_list = pd.DataFrame(self)
fig, ax = plt.subplots(figsize=(30, 3))

Expand Down Expand Up @@ -145,6 +147,8 @@ def plot_epoch_pos_raw_intervals(self, figsize=(20, 5)):
ax.set_yticklabels(["pos valid times", "raw data valid times", "epoch"])
ax.set_xlabel("Time [s]")
ax.grid(True)
if return_fig:
return fig


def intervals_by_length(interval_list, min_length=0.0, max_length=1e10):
Expand Down
Loading

0 comments on commit 0089d5e

Please sign in to comment.