Skip to content

Commit

Permalink
WIP: Subpackage coverage 72%
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed May 20, 2024
1 parent 944de85 commit a46b1a5
Show file tree
Hide file tree
Showing 18 changed files with 335 additions and 133 deletions.
17 changes: 12 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ test = [
"pre-commit", # linting
"pytest", # unit testing
"pytest-cov", # code coverage
"pytest-xvfb", # for headless testing of Qt
]
docs = [
"hatch", # Get version from env
Expand Down Expand Up @@ -120,12 +121,12 @@ ignore-words-list = 'nevers'
[tool.pytest.ini_options]
minversion = "7.0"
addopts = [
"-sv",
"--sw", # stepwise: resume with next test after failure
"--pdb", # drop into debugger on failure
# "-sv", # verbose output
# "--sw", # stepwise: resume with next test after failure
# "--pdb", # drop into debugger on failure
"-p no:warnings",
"--no-teardown", # don't teardown the database after tests
"--quiet-spy", # don't show logging from spyglass
# "--no-teardown", # don't teardown the database after tests
# "--quiet-spy", # don't show logging from spyglass
"--show-capture=no",
"--pdbcls=IPython.terminal.debugger:TerminalPdb", # use ipython debugger
"--cov=spyglass",
Expand All @@ -134,6 +135,12 @@ addopts = [
]
testpaths = ["tests"]
log_level = "INFO"
env = [
"QT_QPA_PLATFORM = offscreen", # QT fails headless without this
# "DISPLAY = :0", # QT fails headless without this
"TF_ENABLE_ONEDNN_OPTS = 0", # TF disable approx calcs
"TF_CPP_MIN_LOG_LEVEL = 2", # Disable TF warnings
]

[tool.coverage.run]
source = ["*/src/spyglass/*"]
Expand Down
15 changes: 15 additions & 0 deletions src/spyglass/decoding/v0/dj_decoder_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,21 @@
ObservationModel,
)
except ImportError as e:
(
Identity,
RandomWalk,
RandomWalkDirection1,
RandomWalkDirection2,
Uniform,
DiagonalDiscrete,
RandomDiscrete,
UniformDiscrete,
UserDefinedDiscrete,
Environment,
UniformInitialConditions,
UniformOneEnvironmentInitialConditions,
ObservationModel,
) = [None] * 13
logger.warning(e)
from track_linearization import make_track_graph

Expand Down
17 changes: 15 additions & 2 deletions src/spyglass/position/v1/dlc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,17 @@ def infer_output_dir(key, makedir=True):
"""
# TODO: add check to make sure interval_list_name refers to a single epoch
# Or make key include epoch in and of itself instead of interval_list_name
nwb_file_name = key["nwb_file_name"].split("_.")[0]

file_name = key.get("nwb_file_name")
dlc_model_name = key.get("dlc_model_name")
epoch = key.get("epoch")

if not all([file_name, dlc_model_name, epoch]):
raise ValueError(
"Key must contain 'nwb_file_name', 'dlc_model_name', and 'epoch'"
)

nwb_file_name = file_name.split("_.")[0]
output_dir = pathlib.Path(dlc_output_dir) / pathlib.Path(
f"{nwb_file_name}/{nwb_file_name}_{key['epoch']:02}"
f"_model_" + key["dlc_model_name"].replace(" ", "-")
Expand Down Expand Up @@ -1021,7 +1031,10 @@ def make_video(
video.release()
out.release()
print("destroying cv2 windows")
cv2.destroyAllWindows()
try:
cv2.destroyAllWindows()
except cv2.error: # if cv is already closed or does not have func
pass
print("finished making video with opencv")
return

Expand Down
12 changes: 8 additions & 4 deletions src/spyglass/position/v1/position_dlc_pose_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class DLCPoseEstimationSelection(SpyglassMixin, dj.Manual):
"""

@classmethod
def get_video_crop(cls, video_path):
def get_video_crop(cls, video_path, crop_input=None):
"""
Queries the user to determine the cropping parameters for a given video
Expand All @@ -61,9 +61,13 @@ def get_video_crop(cls, video_path):
ax.set_yticks(np.arange(ylims[0], ylims[-1], -50))
ax.grid(visible=True, color="white", lw=0.5, alpha=0.5)
display(fig)
crop_input = input(
"Please enter the crop parameters for your video in format xmin, xmax, ymin, ymax, or 'none'\n"
)

if crop_input is None:
crop_input = input(
"Please enter the crop parameters for your video in format "
+ "xmin, xmax, ymin, ymax, or 'none'\n"
)

plt.close()
if crop_input.lower() == "none":
return None
Expand Down
3 changes: 3 additions & 0 deletions src/spyglass/position/v1/position_dlc_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,8 @@ class DLCPosVideo(SpyglassMixin, dj.Computed):
---
"""

# TODO: Shoultn't this keep track of the video file it creates?

def make(self, key):
from tqdm import tqdm as tqdm

Expand Down Expand Up @@ -432,3 +434,4 @@ def make(self, key):
crop=crop,
**params["video_params"],
)
self.insert1(key)
2 changes: 1 addition & 1 deletion src/spyglass/position/v1/position_dlc_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class DLCModelTrainingSelection(SpyglassMixin, dj.Manual):
"""

def insert1(self, key, **kwargs):
training_id = key["training_id"]
training_id = key.get("training_id")
if training_id is None:
training_id = (
dj.U().aggr(self & key, n="max(training_id)").fetch1("n") or 0
Expand Down
16 changes: 15 additions & 1 deletion tests/README.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
# PyTests

## Environment

To facilitate headless testing of various Qt-based tools as well as Tensorflow,
`pyproject.toml` includes some environment variables associated with the
display. These are...

- `QT_QPA_PLATFORM`: Set to `offscreen` to prevent the need for a display.
- `TF_ENABLE_ONEDNN_OPTS`: Set to `1` to enable Tensorflow optimizations.
- `TF_CPP_MIN_LOG_LEVEL`: Set to `2` to suppress Tensorflow warnings.

<!-- - `DISPLAY`: Set to `:0` to prevent the need for a display. -->

## Options

This directory is contains files for testing the code. Simply by running
`pytest` from the root directory, all tests will be run with default parameters
specified in `pyproject.toml`. Notable optional parameters include...

- Coverage items. The coverage report indicates what percentage of the code was
included in tests.

- `--cov=spyglatss`: Which package should be described in the coverage report
- `--cov=spyglass`: Which package should be described in the coverage report
- `--cov-report term-missing`: Include lines of items missing in coverage

- Verbosity.
Expand Down
5 changes: 0 additions & 5 deletions tests/common/test_behav.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,6 @@ def test_populate_state_script(common, pop_state_script):
), "StateScript populate unexpected effect"


@pytest.fixture(scope="session")
def video_keys(common):
return common.VideoFile().fetch(as_dict=True)


@pytest.mark.usefixtures("skipif_noextras")
def test_videofile_update_entries(common, video_keys):
"""Test update entries"""
Expand Down
56 changes: 42 additions & 14 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,14 @@ def populate_exception():
yield PopulateException


# -------------------------- FIXTURES, COMMON TABLES --------------------------


@pytest.fixture(scope="session")
def video_keys(common):
return common.VideoFile().fetch(as_dict=True)


# ------------------------- FIXTURES, POSITION TABLES -------------------------


Expand Down Expand Up @@ -428,7 +436,7 @@ def trodes_params(trodes_params_table, teardown):
},
},
}
trodes_params_table.get_default()
_ = trodes_params_table.get_default()
trodes_params_table.insert(
[v for k, v in paramsets.items()], skip_duplicates=True
)
Expand Down Expand Up @@ -778,10 +786,16 @@ def dlc_project_tbl(sgp):
yield sgp.v1.DLCProject()


@pytest.fixture(scope="session")
def dlc_project_name():
yield "pytest_proj"


@pytest.fixture(scope="session")
def insert_project(
verbose_context,
teardown,
dlc_project_name,
dlc_project_tbl,
common,
bodyparts,
Expand All @@ -791,7 +805,7 @@ def insert_project(
common.LabTeam.insert1({"team_name": team_name}, skip_duplicates=True)
with verbose_context:
project_key = dlc_project_tbl.insert_new_project(
project_name="pytest_proj",
project_name=dlc_project_name,
bodyparts=bodyparts,
lab_team=team_name,
frames_per_video=100,
Expand Down Expand Up @@ -858,8 +872,14 @@ def extract_frames(
)
vid_name = list(dlc_config["video_sets"].keys())[0].split("/")[-1]
label_dir = project_dir / "labeled-data" / vid_name.split(".")[0]

yield label_dir

for file in label_dir.glob("*png"):
if file.stem in ["img000", "img001"]:
continue
file.unlink()


@pytest.fixture(scope="session")
def labeled_vid_dir(extract_frames):
Expand Down Expand Up @@ -889,22 +909,30 @@ def add_training_files(dlc_project_tbl, project_key, fix_downloaded):


@pytest.fixture(scope="session")
def training_params_key(verbose_context, sgp, project_key):
training_params_name = "pytest"
def dlc_training_params(sgp):
params_tbl = sgp.v1.DLCModelTrainingParams()
params_name = "pytest"
yield params_tbl, params_name


@pytest.fixture(scope="session")
def training_params_key(verbose_context, sgp, project_key, dlc_training_params):
params_tbl, params_name = dlc_training_params
with verbose_context:
sgp.v1.DLCModelTrainingParams.insert_new_params(
paramset_name=training_params_name,
params_tbl.insert_new_params(
paramset_name=params_name,
params={
"trainingsetindex": 0,
"shuffle": 1,
"gputouse": None,
"TFGPUinference": False,
"net_type": "resnet_50",
"augmenter_type": "imgaug",
"video_sets": "test skipping param",
},
skip_duplicates=True,
)
yield {"dlc_training_params_name": training_params_name}
yield {"dlc_training_params_name": params_name}


@pytest.fixture(scope="session")
Expand All @@ -913,7 +941,6 @@ def model_train_key(sgp, project_key, training_params_key):
model_train_key = {
**project_key,
**training_params_key,
"training_id": 0,
}
sgp.v1.DLCModelTrainingSelection().insert1(
{
Expand Down Expand Up @@ -974,19 +1001,17 @@ def pose_estimation_key(sgp, mini_copy_name, populate_model, model_key):

@pytest.fixture(scope="session")
def populate_pose_estimation(sgp, pose_estimation_key):
pose_est_tbl = sgp.v1.DLCPoseEstimation
if pose_est_tbl & pose_estimation_key:
yield
else:
pose_est_tbl = sgp.v1.DLCPoseEstimation()
if len(pose_est_tbl & pose_estimation_key) < 1:
pose_est_tbl.populate(pose_estimation_key)
yield
yield pose_est_tbl


@pytest.fixture(scope="session")
def si_params_name(sgp, populate_pose_estimation):
params_name = "low_bar"
params_tbl = sgp.v1.DLCSmoothInterpParams
# if len(params_tbl & {"dlc_si_params_name": params_name}) == 0:
# if len(params_tbl & {"dlc_si_params_name": params_name}) < 1:
if True: # TODO: remove before merge
nan_params = params_tbl.get_nan_params()
nan_params["dlc_si_params_name"] = params_name
Expand All @@ -995,6 +1020,9 @@ def si_params_name(sgp, populate_pose_estimation):
"likelihood_thresh": 0.4,
"max_cm_between_pts": 100,
"num_inds_to_span": 50,
# Smoothing and Interpolation added later - must check
"smoothing_params": {"smoothing_duration": 0.05},
"interp_params": {"max_cm_to_interp": 100},
}
)
params_tbl.insert1(nan_params, skip_duplicates=True)
Expand Down
17 changes: 4 additions & 13 deletions tests/position/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,6 @@
58, 61, 69, 72, 97-100, 104, 149-161, 232-235, 239-241, 246, 259, 280,
293-305, 310-316, 328-341, 356-373, 395, 404, 480, 487-488, 530, 548-561,
594-601, 611-612, 641-657, 682-736, 762-772, 787, 809-1286
TODO: tests for
pose_estimat 51-71, 102, 115, 256, 345-366
position.py 53, 99, 114, 119, 197-198, 205-219, 349, 353-355, 360, 382, 385, 407, 443-466
project.py 45-54, 128-205, 250-255, 259, 278-303, 316, 347, 361-413, 425, 457, 476-479, 486-489, 514-555, 582, 596
selection.py 213, 282, 308-417
training.py 55, 67-73, 85-87, 113, 143-144, 161, 207-210
es_position.py 67, 282-283, 361-362, 496, 502-503
"""

from itertools import product as iter_prodect
Expand All @@ -33,7 +24,7 @@
def dlc_video_params(sgp):
sgp.v1.DLCPosVideoParams.insert_default()
params_key = {"dlc_pos_video_params_name": "five_percent"}
sgp.v1.DLCPosVideoSelection.insert1(
sgp.v1.DLCPosVideoParams.insert1(
{
**params_key,
"params": {
Expand All @@ -47,7 +38,7 @@ def dlc_video_params(sgp):


@pytest.fixture(scope="session")
def dlc_video_selection(sgp, dlc_key, dlc_video_params):
def dlc_video_selection(sgp, dlc_key, dlc_video_params, populate_dlc):
s_key = {**dlc_key, **dlc_video_params}
sgp.v1.DLCPosVideoSelection.insert1(s_key, skip_duplicates=True)
yield dlc_key
Expand All @@ -56,7 +47,7 @@ def dlc_video_selection(sgp, dlc_key, dlc_video_params):
@pytest.fixture(scope="session")
def populate_dlc_video(sgp, dlc_video_selection):
sgp.v1.DLCPosVideo.populate(dlc_video_selection)
yield
yield sgp.v1.DLCPosVideo()


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -95,4 +86,4 @@ def increment_count():
count[0] += 1
return count[0]

return df.map(lambda x: increment_count() if x == 1 else x)
return df.applymap(lambda x: increment_count() if x == 1 else x)
Loading

0 comments on commit a46b1a5

Please sign in to comment.