diff --git a/src/spyglass/position/v1/position_dlc_centroid.py b/src/spyglass/position/v1/position_dlc_centroid.py index f1f077d6a..70a1c1252 100644 --- a/src/spyglass/position/v1/position_dlc_centroid.py +++ b/src/spyglass/position/v1/position_dlc_centroid.py @@ -170,7 +170,7 @@ def make(self, key): for point in required_points: bodypart = points[point] if bodypart not in bodyparts_avail: - raise ValueError( + raise ValueError( # TODO: migrate to input validation "Bodypart in points not in model." f"\tBodypart {bodypart}" f"\tIn Model {bodyparts_avail}" @@ -222,6 +222,7 @@ def make(self, key): "smoothing_duration" ) if not smoothing_duration: + # TODO: remove - validated with `validate_smooth_params` raise KeyError( "smoothing_duration needs to be passed within smoothing_params" ) @@ -368,6 +369,7 @@ def four_led_centroid(pos_df: pd.DataFrame, **params): """Determines the centroid of 4 LEDS on an implant LED ring. Assumed to be the Green LED, and 3 red LEDs called: redLED_C, redLED_L, redLED_R By default, uses (greenled + redLED_C) / 2 to calculate centroid + If Green LED is NaN, but red center LED is not, then the red center LED is called the centroid If green and red center LEDs are NaN, but red left and red right LEDs are not, @@ -397,6 +399,9 @@ def four_led_centroid(pos_df: pd.DataFrame, **params): numpy array with shape (n_time, 2) centroid[0] is the x coord and centroid[1] is the y coord """ + if not (params.get("max_LED_separation") and params.get("points")): + raise KeyError("max_LED_separation/points need to be passed in params") + centroid = np.zeros(shape=(len(pos_df), 2)) idx = pd.IndexSlice # TODO: this feels messy, clean-up @@ -722,6 +727,8 @@ def two_pt_centroid(pos_df: pd.DataFrame, **params): numpy array with shape (n_time, 2) centroid[0] is the x coord and centroid[1] is the y coord """ + if not (params.get("max_LED_separation") and params.get("points")): + raise KeyError("max_LED_separation/points need to be passed in params") idx = pd.IndexSlice centroid = np.zeros(shape=(len(pos_df), 2)) @@ -797,6 +804,8 @@ def one_pt_centroid(pos_df: pd.DataFrame, **params): numpy array with shape (n_time, 2) centroid[0] is the x coord and centroid[1] is the y coord """ + if not params.get("points"): + raise KeyError("points need to be passed in params") idx = pd.IndexSlice PT1 = params["points"].pop("point1", None) centroid = pos_df.loc[:, idx[PT1, ("x", "y")]].to_numpy() diff --git a/src/spyglass/position/v1/position_dlc_position.py b/src/spyglass/position/v1/position_dlc_position.py index eafe0ada0..c18eafd62 100644 --- a/src/spyglass/position/v1/position_dlc_position.py +++ b/src/spyglass/position/v1/position_dlc_position.py @@ -180,7 +180,7 @@ def make(self, key): bp_key = key.copy() if test_mode: # during testing, analysis_file not in BodyPart table - bp_key.pop("analysis_file_name") + bp_key.pop("analysis_file_name", None) dlc_df = (DLCPoseEstimation.BodyPart() & bp_key).fetch1_dataframe() dt = np.median(np.diff(dlc_df.index.to_numpy())) diff --git a/tests/conftest.py b/tests/conftest.py index 918e9c0c5..2e2f5633b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,6 +10,7 @@ import warnings from contextlib import nullcontext from pathlib import Path +from shutil import rmtree as shutil_rmtree from time import sleep as tsleep import datajoint as dj @@ -17,6 +18,7 @@ import pynwb import pytest from datajoint.logging import logger as dj_logger +from deeplabcut.utils.auxiliaryfunctions import read_config, write_config from numba import NumbaWarning from pandas.errors 
import PerformanceWarning @@ -754,3 +756,428 @@ def lfp_merge_key(populate_lfp): @pytest.fixture(scope="session") def lfp_v1_key(lfp, lfp_s_key): yield (lfp.v1.LFPV1 & lfp_s_key).fetch1("KEY") + + +# --------------------------- FIXTURES, DLC TABLES ---------------------------- +# ---------------- Note: DLCOutput is used to test RestrGraph ----------------- + + +@pytest.fixture(scope="session") +def bodyparts(sgp): + bps = ["whiteLED", "tailBase", "tailMid", "tailTip"] + sgp.v1.BodyPart.insert( + [{"bodypart": bp, "bodypart_description": "none"} for bp in bps], + skip_duplicates=True, + ) + + yield bps + + +@pytest.fixture(scope="session") +def dlc_project_tbl(sgp): + yield sgp.v1.DLCProject() + + +@pytest.fixture(scope="session") +def insert_project( + verbose_context, + teardown, + dlc_project_tbl, + common, + bodyparts, + mini_copy_name, +): + team_name = "sc_eb" + common.LabTeam.insert1({"team_name": team_name}, skip_duplicates=True) + with verbose_context: + project_key = dlc_project_tbl.insert_new_project( + project_name="pytest_proj", + bodyparts=bodyparts, + lab_team=team_name, + frames_per_video=100, + video_list=[ + {"nwb_file_name": mini_copy_name, "epoch": 0}, + {"nwb_file_name": mini_copy_name, "epoch": 1}, + ], + skip_duplicates=True, + ) + config_path = (dlc_project_tbl & project_key).fetch1("config_path") + cfg = read_config(config_path) + cfg.update( + { + "numframes2pick": 2, + "maxiters": 2, + "scorer": team_name, + "skeleton": [ + ["whiteLED"], + [ + ["tailMid", "tailMid"], + ["tailBase", "tailBase"], + ["tailTip", "tailTip"], + ], + ], # eb's has video_sets: {1: {'crop': [0, 1260, 0, 728]}} + } + ) + + write_config(config_path, cfg) + + yield project_key, cfg, config_path + + if teardown: + (dlc_project_tbl & project_key).delete(safemode=False) + shutil_rmtree(str(Path(config_path).parent)) + + +@pytest.fixture(scope="session") +def project_key(insert_project): + yield insert_project[0] + + +@pytest.fixture(scope="session") +def dlc_config(insert_project): + yield insert_project[1] + + +@pytest.fixture(scope="session") +def config_path(insert_project): + yield insert_project[2] + + +@pytest.fixture(scope="session") +def project_dir(config_path): + yield Path(config_path).parent + + +@pytest.fixture(scope="session") +def extract_frames( + verbose_context, dlc_project_tbl, project_key, dlc_config, project_dir +): + with verbose_context: + dlc_project_tbl.run_extract_frames( + project_key, userfeedback=False, mode="automatic" + ) + vid_name = list(dlc_config["video_sets"].keys())[0].split("/")[-1] + label_dir = project_dir / "labeled-data" / vid_name.split(".")[0] + yield label_dir + + +@pytest.fixture(scope="session") +def labeled_vid_dir(extract_frames): + yield extract_frames + + +@pytest.fixture(scope="session") +def fix_downloaded(labeled_vid_dir, project_dir): + """Grabs CollectedData and img files from project_dir, moves to labeled""" + for file in project_dir.parent.parent.glob("*"): + if file.is_dir(): + continue + dest = labeled_vid_dir / file.name + if dest.exists(): + dest.unlink() + dest.write_bytes(file.read_bytes()) + # TODO: revert to rename before merge + # file.rename(labeled_vid_dir / file.name) + + yield + + +@pytest.fixture(scope="session") +def add_training_files(dlc_project_tbl, project_key, fix_downloaded): + dlc_project_tbl.add_training_files(project_key, skip_duplicates=True) + yield + + +@pytest.fixture(scope="session") +def training_params_key(verbose_context, sgp, project_key): + training_params_name = "pytest" + with verbose_context: + 
sgp.v1.DLCModelTrainingParams.insert_new_params( + paramset_name=training_params_name, + params={ + "trainingsetindex": 0, + "shuffle": 1, + "gputouse": None, + "TFGPUinference": False, + "net_type": "resnet_50", + "augmenter_type": "imgaug", + }, + skip_duplicates=True, + ) + yield {"dlc_training_params_name": training_params_name} + + +@pytest.fixture(scope="session") +def model_train_key(sgp, project_key, training_params_key): + _ = project_key.pop("config_path", None) + model_train_key = { + **project_key, + **training_params_key, + "training_id": 0, + } + sgp.v1.DLCModelTrainingSelection().insert1( + { + **model_train_key, + "model_prefix": "", + }, + skip_duplicates=True, + ) + yield model_train_key + + +@pytest.fixture(scope="session") +def populate_training(sgp, fix_downloaded, model_train_key, add_training_files): + train_tbl = sgp.v1.DLCModelTraining + if len(train_tbl & model_train_key) == 0: + _ = add_training_files + _ = fix_downloaded + sgp.v1.DLCModelTraining.populate(model_train_key) + yield model_train_key + + +@pytest.fixture(scope="session") +def model_source_key(sgp, model_train_key, populate_training): + yield (sgp.v1.DLCModelSource & model_train_key).fetch1("KEY") + + +@pytest.fixture(scope="session") +def model_key(sgp, model_source_key): + model_key = {**model_source_key, "dlc_model_params_name": "default"} + _ = sgp.v1.DLCModelParams.get_default() + sgp.v1.DLCModelSelection().insert1(model_key, skip_duplicates=True) + yield model_key + + +@pytest.fixture(scope="session") +def populate_model(sgp, model_key): + model_tbl = sgp.v1.DLCModel + if model_tbl & model_key: + yield + else: + sgp.v1.DLCModel.populate(model_key) + yield + + +@pytest.fixture(scope="session") +def pose_estimation_key(sgp, mini_copy_name, populate_model, model_key): + yield sgp.v1.DLCPoseEstimationSelection.insert_estimation_task( + { + "nwb_file_name": mini_copy_name, + "epoch": 1, + "video_file_num": 0, + **model_key, + }, + task_mode="trigger", # trigger or load + params={"gputouse": None, "videotype": "mp4", "TFGPUinference": False}, + ) + + +@pytest.fixture(scope="session") +def populate_pose_estimation(sgp, pose_estimation_key): + pose_est_tbl = sgp.v1.DLCPoseEstimation + if pose_est_tbl & pose_estimation_key: + yield + else: + pose_est_tbl.populate(pose_estimation_key) + yield + + +@pytest.fixture(scope="session") +def si_params_name(sgp, populate_pose_estimation): + params_name = "low_bar" + params_tbl = sgp.v1.DLCSmoothInterpParams + # if len(params_tbl & {"dlc_si_params_name": params_name}) == 0: + if True: # TODO: remove before merge + nan_params = params_tbl.get_nan_params() + nan_params["dlc_si_params_name"] = params_name + nan_params["params"].update( + { + "likelihood_thresh": 0.4, + "max_cm_between_pts": 100, + "num_inds_to_span": 50, + } + ) + params_tbl.insert1(nan_params, skip_duplicates=True) + + yield params_name + + +@pytest.fixture(scope="session") +def si_key(sgp, bodyparts, si_params_name, pose_estimation_key): + key = { + key: val + for key, val in pose_estimation_key.items() + if key in sgp.v1.DLCSmoothInterpSelection.primary_key + } + sgp.v1.DLCSmoothInterpSelection.insert( + [ + { + **key, + "bodypart": bodypart, + "dlc_si_params_name": si_params_name, + } + for bodypart in bodyparts[:1] + ], + skip_duplicates=True, + ) + yield key + + +@pytest.fixture(scope="session") +def populate_si(sgp, si_key, populate_pose_estimation): + sgp.v1.DLCSmoothInterp.populate() + yield + + +@pytest.fixture(scope="session") +def cohort_selection(sgp, si_key, si_params_name): + 
cohort_key = { + k: v + for k, v in { + **si_key, + "dlc_si_cohort_selection_name": "whiteLED", + "bodyparts_params_dict": { + "whiteLED": si_params_name, + }, + }.items() + if k not in ["bodypart", "dlc_si_params_name"] + } + sgp.v1.DLCSmoothInterpCohortSelection().insert1( + cohort_key, skip_duplicates=True + ) + yield cohort_key + + +@pytest.fixture(scope="session") +def cohort_key(sgp, cohort_selection): + yield cohort_selection.copy() + + +@pytest.fixture(scope="session") +def populate_cohort(sgp, cohort_selection, populate_si): + sgp.v1.DLCSmoothInterpCohort.populate(cohort_selection) + + +@pytest.fixture(scope="session") +def centroid_params(sgp): + params_tbl = sgp.v1.DLCCentroidParams + params_key = {"dlc_centroid_params_name": "one_test"} + if len(params_tbl & params_key) == 0: + params_tbl.insert1( + { + **params_key, + "params": { + "centroid_method": "one_pt_centroid", + "points": {"point1": "whiteLED"}, + "interpolate": True, + "interp_params": {"max_cm_to_interp": 100}, + "smooth": True, + "smoothing_params": { + "smoothing_duration": 0.05, + "smooth_method": "moving_avg", + }, + "max_LED_separation": 50, + "speed_smoothing_std_dev": 0.100, + }, + } + ) + yield params_key + + +@pytest.fixture(scope="session") +def centroid_selection(sgp, cohort_key, populate_cohort, centroid_params): + centroid_key = cohort_key.copy() + centroid_key = { + key: val + for key, val in cohort_key.items() + if key in sgp.v1.DLCCentroidSelection.primary_key + } + centroid_key.update(centroid_params) + sgp.v1.DLCCentroidSelection.insert1(centroid_key, skip_duplicates=True) + yield centroid_key + + +@pytest.fixture(scope="session") +def centroid_key(sgp, centroid_selection): + yield centroid_selection.copy() + + +@pytest.fixture(scope="session") +def populate_centroid(sgp, centroid_selection): + sgp.v1.DLCCentroid.populate(centroid_selection) + + +@pytest.fixture(scope="session") +def orient_params(sgp): + params_tbl = sgp.v1.DLCOrientationParams + params_key = {"dlc_orientation_params_name": "none"} + if len(params_tbl & params_key) == 0: + params_tbl.insert1( + { + **params_key, + "params": { + "orient_method": "none", + "bodypart1": "whiteLED", + "orientation_smoothing_std_dev": 0.001, + }, + } + ) + return params_key + + +@pytest.fixture(scope="session") +def orient_selection(sgp, cohort_key, orient_params): + orient_key = { + key: val + for key, val in cohort_key.items() + if key in sgp.v1.DLCOrientationSelection.primary_key + } + orient_key.update(orient_params) + sgp.v1.DLCOrientationSelection().insert1(orient_key, skip_duplicates=True) + yield orient_key + + +@pytest.fixture(scope="session") +def orient_key(sgp, orient_selection): + yield orient_selection.copy() + + +@pytest.fixture(scope="session") +def populate_orient(sgp, orient_selection): + sgp.v1.DLCOrientation().populate(orient_selection) + yield + + +@pytest.fixture(scope="session") +def dlc_selection(sgp, centroid_key, orient_key, populate_orient): + dlc_key = { + key: val + for key, val in centroid_key.items() + if key in sgp.v1.DLCPosV1.primary_key + } + dlc_key.update( + { + "dlc_si_cohort_centroid": centroid_key[ + "dlc_si_cohort_selection_name" + ], + "dlc_si_cohort_orientation": orient_key[ + "dlc_si_cohort_selection_name" + ], + "dlc_orientation_params_name": orient_key[ + "dlc_orientation_params_name" + ], + } + ) + sgp.v1.DLCPosSelection().insert1(dlc_key, skip_duplicates=True) + yield dlc_key + + +@pytest.fixture(scope="session") +def dlc_key(sgp, dlc_selection): + yield dlc_selection.copy() + + 
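A pattern repeated throughout the selection fixtures added above (si_key, cohort_selection, centroid_selection, orient_selection, dlc_selection) is to project an upstream key down to the downstream table's primary key and then layer parameter names on top before inserting. A minimal standalone sketch of that pattern; the helper name and the commented usage are illustrative, not part of this diff:

    def restrict_to_primary_key(upstream_key: dict, table) -> dict:
        """Keep only the fields that belong to `table`'s primary key."""
        return {k: v for k, v in upstream_key.items() if k in table.primary_key}

    # Hypothetical usage mirroring the dlc_selection fixture:
    # dlc_key = restrict_to_primary_key(centroid_key, sgp.v1.DLCPosV1)
    # dlc_key["dlc_si_cohort_centroid"] = centroid_key["dlc_si_cohort_selection_name"]
    # sgp.v1.DLCPosSelection().insert1(dlc_key, skip_duplicates=True)

Because each fixture only carries forward the fields the next selection table actually declares, the session-scoped chain can be composed without leaking extra attributes into insert1 calls.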
+@pytest.fixture(scope="session") +def populate_dlc(sgp, dlc_key): + sgp.v1.DLCPosV1().populate(dlc_key) + yield diff --git a/tests/position/__init__.py b/tests/position/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/position/conftest.py b/tests/position/conftest.py index 83b232bba..1aaec3384 100644 --- a/tests/position/conftest.py +++ b/tests/position/conftest.py @@ -1,414 +1,32 @@ -from pathlib import Path -from shutil import rmtree as shutil_rmtree - +""" +The following lines are not used in the course of regular pose processing and +can be removed so long as other functionality is not impacted. + +position_merge.py: 106-107, 110-123, 139-262 +dlc_decorators.py: 11, 16-18, 22 +dlc_reader.py : + 24, 38, 44-45, 51, 57-58, 61, 70, 74, 80-81, 135-137, 146, 149-162, 214, + 218 +dlc_utils.py : + 58, 61, 69, 72, 97-100, 104, 149-161, 232-235, 239-241, 246, 259, 280, + 293-305, 310-316, 328-341, 356-373, 395, 404, 480, 487-488, 530, 548-561, + 594-601, 611-612, 641-657, 682-736, 762-772, 787, 809-1286 + +TODO: tests for +pose_estimat 51-71, 102, 115, 256, 345-366 +position.py 53, 99, 114, 119, 197-198, 205-219, 349, 353-355, 360, 382, 385, 407, 443-466 +project.py 45-54, 128-205, 250-255, 259, 278-303, 316, 347, 361-413, 425, 457, 476-479, 486-489, 514-555, 582, 596 +selection.py 213, 282, 308-417 +training.py 55, 67-73, 85-87, 113, 143-144, 161, 207-210 +es_position.py 67, 282-283, 361-362, 496, 502-503 + +""" + +from itertools import product as iter_prodect + +import numpy as np +import pandas as pd import pytest -from deeplabcut.utils.auxiliaryfunctions import read_config, write_config - - -@pytest.fixture(scope="session") -def bodyparts(sgp): - bps = ["whiteLED", "tailBase", "tailMid", "tailTip"] - sgp.v1.BodyPart.insert( - [{"bodypart": bp, "bodypart_description": "none"} for bp in bps], - skip_duplicates=True, - ) - - yield bps - - -@pytest.fixture(scope="session") -def dlc_project_tbl(sgp): - yield sgp.v1.DLCProject() - - -@pytest.fixture(scope="session") -def insert_project( - verbose_context, - teardown, - dlc_project_tbl, - common, - bodyparts, - mini_copy_name, -): - team_name = "sc_eb" - common.LabTeam.insert1({"team_name": team_name}, skip_duplicates=True) - with verbose_context: - project_key = dlc_project_tbl.insert_new_project( - project_name="pytest_proj", - bodyparts=bodyparts, - lab_team=team_name, - frames_per_video=100, - video_list=[ - {"nwb_file_name": mini_copy_name, "epoch": 0}, - {"nwb_file_name": mini_copy_name, "epoch": 1}, - ], - skip_duplicates=True, - ) - config_path = (dlc_project_tbl & project_key).fetch1("config_path") - cfg = read_config(config_path) - cfg.update( - { - "numframes2pick": 2, - "maxiters": 2, - "scorer": team_name, - "skeleton": [ - ["whiteLED"], - [ - ["tailMid", "tailMid"], - ["tailBase", "tailBase"], - ["tailTip", "tailTip"], - ], - ], # eb's has video_sets: {1: {'crop': [0, 1260, 0, 728]}} - } - ) - - write_config(config_path, cfg) - - yield project_key, cfg, config_path - - if teardown: - (dlc_project_tbl & project_key).delete(safemode=False) - shutil_rmtree(str(Path(config_path).parent)) - - -@pytest.fixture(scope="session") -def project_key(insert_project): - yield insert_project[0] - - -@pytest.fixture(scope="session") -def dlc_config(insert_project): - yield insert_project[1] - - -@pytest.fixture(scope="session") -def config_path(insert_project): - yield insert_project[2] - - -@pytest.fixture(scope="session") -def project_dir(config_path): - yield Path(config_path).parent - - 
-@pytest.fixture(scope="session") -def extract_frames( - verbose_context, dlc_project_tbl, project_key, dlc_config, project_dir -): - with verbose_context: - dlc_project_tbl.run_extract_frames( - project_key, userfeedback=False, mode="automatic" - ) - vid_name = list(dlc_config["video_sets"].keys())[0].split("/")[-1] - label_dir = project_dir / "labeled-data" / vid_name.split(".")[0] - yield label_dir - - -@pytest.fixture(scope="session") -def labeled_vid_dir(extract_frames): - yield extract_frames - - -@pytest.fixture(scope="session") -def fix_downloaded(labeled_vid_dir, project_dir): - """Grabs CollectedData and img files from project_dir, moves to labeled""" - for file in project_dir.parent.parent.glob("*"): - if file.is_dir(): - continue - dest = labeled_vid_dir / file.name - if dest.exists(): - dest.unlink() - dest.write_bytes(file.read_bytes()) - # TODO: revert to rename before merge - # file.rename(labeled_vid_dir / file.name) - - yield - - -@pytest.fixture(scope="session") -def add_training_files(dlc_project_tbl, project_key, fix_downloaded): - dlc_project_tbl.add_training_files(project_key, skip_duplicates=True) - yield - - -@pytest.fixture(scope="session") -def training_params_key(verbose_context, sgp, project_key): - training_params_name = "pytest" - with verbose_context: - sgp.v1.DLCModelTrainingParams.insert_new_params( - paramset_name=training_params_name, - params={ - "trainingsetindex": 0, - "shuffle": 1, - "gputouse": None, - "TFGPUinference": False, - "net_type": "resnet_50", - "augmenter_type": "imgaug", - }, - skip_duplicates=True, - ) - yield {"dlc_training_params_name": training_params_name} - - -@pytest.fixture(scope="session") -def model_train_key(sgp, project_key, training_params_key): - _ = project_key.pop("config_path", None) - model_train_key = { - **project_key, - **training_params_key, - "training_id": 0, - } - sgp.v1.DLCModelTrainingSelection().insert1( - { - **model_train_key, - "model_prefix": "", - }, - skip_duplicates=True, - ) - yield model_train_key - - -@pytest.fixture(scope="session") -def populate_training(sgp, fix_downloaded, model_train_key, add_training_files): - train_tbl = sgp.v1.DLCModelTraining - if len(train_tbl & model_train_key) == 0: - _ = add_training_files - _ = fix_downloaded - sgp.v1.DLCModelTraining.populate(model_train_key) - yield model_train_key - - -@pytest.fixture(scope="session") -def model_source_key(sgp, model_train_key, populate_training): - yield (sgp.v1.DLCModelSource & model_train_key).fetch1("KEY") - - -@pytest.fixture(scope="session") -def model_key(sgp, model_source_key): - model_key = {**model_source_key, "dlc_model_params_name": "default"} - _ = sgp.v1.DLCModelParams.get_default() - sgp.v1.DLCModelSelection().insert1(model_key, skip_duplicates=True) - yield model_key - - -@pytest.fixture(scope="session") -def populate_model(sgp, model_key): - model_tbl = sgp.v1.DLCModel - if model_tbl & model_key: - yield - else: - sgp.v1.DLCModel.populate(model_key) - yield - - -@pytest.fixture(scope="session") -def pose_estimation_key(sgp, mini_copy_name, populate_model, model_key): - yield sgp.v1.DLCPoseEstimationSelection.insert_estimation_task( - { - "nwb_file_name": mini_copy_name, - "epoch": 1, - "video_file_num": 0, - **model_key, - }, - task_mode="trigger", # trigger or load - params={"gputouse": None, "videotype": "mp4", "TFGPUinference": False}, - ) - - -@pytest.fixture(scope="session") -def populate_pose_estimation(sgp, pose_estimation_key): - pose_est_tbl = sgp.v1.DLCPoseEstimation - if pose_est_tbl & pose_estimation_key: 
- yield - else: - pose_est_tbl.populate(pose_estimation_key) - yield - - -@pytest.fixture(scope="session") -def si_params_name(sgp, populate_pose_estimation): - params_name = "low_bar" - params_tbl = sgp.v1.DLCSmoothInterpParams - # if len(params_tbl & {"dlc_si_params_name": params_name}) == 0: - if True: # TODO: remove before merge - nan_params = params_tbl.get_nan_params() - nan_params["dlc_si_params_name"] = params_name - nan_params["params"].update( - { - "likelihood_thresh": 0.4, - "max_cm_between_pts": 100, - "num_inds_to_span": 50, - } - ) - params_tbl.insert1(nan_params, skip_duplicates=True) - - yield params_name - - -@pytest.fixture(scope="session") -def si_key(sgp, bodyparts, si_params_name, pose_estimation_key): - key = { - key: val - for key, val in pose_estimation_key.items() - if key in sgp.v1.DLCSmoothInterpSelection.primary_key - } - sgp.v1.DLCSmoothInterpSelection.insert( - [ - { - **key, - "bodypart": bodypart, - "dlc_si_params_name": si_params_name, - } - for bodypart in bodyparts[:1] - ], - skip_duplicates=True, - ) - yield key - - -@pytest.fixture(scope="session") -def populate_si(sgp, si_key, populate_pose_estimation): - sgp.v1.DLCSmoothInterp.populate() - yield - - -@pytest.fixture(scope="session") -def cohort_selection(sgp, si_key, si_params_name): - cohort_key = { - k: v - for k, v in { - **si_key, - "dlc_si_cohort_selection_name": "whiteLED", - "bodyparts_params_dict": { - "whiteLED": si_params_name, - }, - }.items() - if k not in ["bodypart", "dlc_si_params_name"] - } - sgp.v1.DLCSmoothInterpCohortSelection().insert1( - cohort_key, skip_duplicates=True - ) - yield cohort_key - - -@pytest.fixture(scope="session") -def cohort_key(sgp, cohort_selection): - yield cohort_selection.copy() - - -@pytest.fixture(scope="session") -def populate_cohort(sgp, cohort_selection, populate_si): - sgp.v1.DLCSmoothInterpCohort.populate(cohort_selection) - - -@pytest.fixture(scope="session") -def centroid_params(sgp): - params_tbl = sgp.v1.DLCCentroidParams - params_key = {"dlc_centroid_params_name": "one_test"} - if len(params_tbl & params_key) == 0: - params_tbl.insert1( - { - **params_key, - "params": { - "centroid_method": "one_pt_centroid", - "points": {"point1": "whiteLED"}, - "interpolate": True, - "interp_params": {"max_cm_to_interp": 100}, - "smooth": True, - "smoothing_params": { - "smoothing_duration": 0.05, - "smooth_method": "moving_avg", - }, - "max_LED_separation": 50, - "speed_smoothing_std_dev": 0.100, - }, - } - ) - yield params_key - - -@pytest.fixture(scope="session") -def centroid_selection(sgp, cohort_key, populate_cohort, centroid_params): - centroid_key = cohort_key.copy() - centroid_key = { - key: val - for key, val in cohort_key.items() - if key in sgp.v1.DLCCentroidSelection.primary_key - } - centroid_key.update(centroid_params) - sgp.v1.DLCCentroidSelection.insert1(centroid_key, skip_duplicates=True) - yield centroid_key - - -@pytest.fixture(scope="session") -def centroid_key(sgp, centroid_selection): - yield centroid_selection.copy() - - -@pytest.fixture(scope="session") -def populate_centroid(sgp, centroid_selection): - sgp.v1.DLCCentroid.populate(centroid_selection) - - -@pytest.fixture(scope="session") -def orient_params(sgp): - params_tbl = sgp.v1.DLCOrientationParams - params_key = {"dlc_orientation_params_name": "none"} - if len(params_tbl & params_key) == 0: - params_tbl.insert1({**params_key, "params": {}}) - return params_key - - -@pytest.fixture(scope="session") -def orient_selection(sgp, cohort_key, orient_params): - orient_key = { - key: val 
- for key, val in cohort_key.items() - if key in sgp.v1.DLCOrientationSelection.primary_key - } - orient_key.update(orient_params) - sgp.v1.DLCOrientationSelection().insert1(orient_key, skip_duplicates=True) - yield orient_key - - -@pytest.fixture(scope="session") -def orient_key(sgp, orient_selection): - yield orient_selection.copy() - - -@pytest.fixture(scope="session") -def populate_orient(sgp, orient_selection): - sgp.v1.DLCOrientation().populate(orient_selection) - yield - - -@pytest.fixture(scope="session") -def dlc_selection(sgp, centroid_key, orient_key, populate_orient): - dlc_key = { - key: val - for key, val in centroid_key.items() - if key in sgp.v1.DLCPosV1.primary_key - } - dlc_key.update( - { - "dlc_si_cohort_centroid": centroid_key[ - "dlc_si_cohort_selection_name" - ], - } - ) - sgp.v1.DLCPosSelection().insert1(dlc_key, skip_duplicates=True) - yield dlc_key - - -@pytest.fixture(scope="session") -def dlc_key(sgp, dlc_selection): - yield dlc_selection.copy() - - -@pytest.fixture(scope="session") -def populate_dlc(sgp, dlc_key): - sgp.v1.DLCPosV1().populate(dlc_key) - yield @pytest.fixture(scope="session") @@ -439,3 +57,42 @@ def dlc_video_selection(sgp, dlc_key, dlc_video_params): def populate_dlc_video(sgp, dlc_video_selection): sgp.v1.DLCPosVideo.populate(dlc_video_selection) yield + + +@pytest.fixture(scope="session") +def populate_evaluation(sgp, populate_model): + sgp.v1.DLCEvaluation.populate() + yield + + +def generate_led_df(leds, inc_vals=False): + """Returns df with all combinations of 1 and np.nan for each led. + + If inc_vals is True, the values will be incremented by 1 for each non-nan""" + all_vals = list(zip(*iter_prodect([1, np.nan], repeat=len(leds)))) + n_rows = len(all_vals[0]) + indices = np.random.uniform(1.6223e09, 1.6224e09, n_rows) + + data = dict() + for led, values in zip(leds, all_vals): + data.update( + { + (led, "video_frame_id"): { + i: f for i, f in zip(indices, range(n_rows + 1)) + }, + (led, "x"): {i: v for i, v in zip(indices, values)}, + (led, "y"): {i: v for i, v in zip(indices, values)}, + } + ) + df = pd.DataFrame(data) + + if not inc_vals: + return df + + count = [0] + + def increment_count(): + count[0] += 1 + return count[0] + + return df.map(lambda x: increment_count() if x == 1 else x) diff --git a/tests/position/test_dlc_cent.py b/tests/position/test_dlc_cent.py index d83973c90..86ccba275 100644 --- a/tests/position/test_dlc_cent.py +++ b/tests/position/test_dlc_cent.py @@ -1,9 +1,12 @@ +import numpy as np import pytest from numpy import isclose as np_isclose +from .conftest import generate_led_df + @pytest.fixture(scope="session") -def centroid_df(sgp, centroid_key): +def centroid_df(sgp, centroid_key, populate_centroid): yield (sgp.v1.DLCCentroid & centroid_key).fetch1_dataframe() @@ -23,3 +26,41 @@ def test_centroid_fetch1_dataframe(centroid_df, column, exp_sum): assert np_isclose( centroid_df[column].sum(), exp_sum, atol=tolerance ), f"Sum of {column} in Centroid dataframe is not as expected" + + +@pytest.fixture(scope="session") +def params_tbl(sgp): + yield sgp.v1.DLCCentroidParams() + + +def test_insert_default_params(params_tbl): + ret = params_tbl.get_default() + assert "default" in params_tbl.fetch( + "dlc_centroid_params_name" + ), "Default params not inserted" + assert ( + ret["dlc_centroid_params_name"] == "default" + ), "Default params not inserted" + + +def test_validate_params(params_tbl): + params = params_tbl.get_default() + params["dlc_centroid_params_name"] = "test" + params_tbl.insert1(params) + + 
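The parametrized test below looks up each centroid function in _key_to_func_dict and feeds it a frame built by generate_led_df, which enumerates every combination of 1 and NaN across the requested LEDs: rows with usable LEDs should come out as 1, and the final all-NaN row should come out as NaN. A minimal sketch of the combination logic that generate_led_df relies on (LED names here are illustrative):

    import numpy as np
    from itertools import product

    leds = ["greenLED", "redLED_C"]
    # One tuple of per-row values for each LED; with two LEDs there are 2**2 = 4 rows.
    all_vals = list(zip(*product([1, np.nan], repeat=len(leds))))
    # all_vals[0] -> (1, 1, nan, nan)   values assigned to greenLED
    # all_vals[1] -> (1, nan, 1, nan)   values assigned to redLED_C

Each per-LED tuple then becomes the "x" and "y" columns of a two-level (led, field) column index, so every missing-data case is exercised in a single DataFrame.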
+@pytest.mark.parametrize( + "key", ["four_led_centroid", "two_pt_centroid", "one_pt_centroid"] +) +def test_centroid_calcs(key, sgp): + points = sgp.v1.position_dlc_centroid._key_to_points[key] + func = sgp.v1.position_dlc_centroid._key_to_func_dict[key] + + df = generate_led_df(points) + ret = func(df, max_LED_separation=100, points={p: p for p in points}) + + assert np.all(ret[:-1] == 1), f"Centroid calculation failed for {key}" + assert np.all(np.isnan(ret[-1])), f"Centroid calculation failed for {key}" + + with pytest.raises(KeyError): + func(df) # Missing led separation/point names diff --git a/tests/position/test_dlc_model.py b/tests/position/test_dlc_model.py index f2ad93d77..6f1ccf89d 100644 --- a/tests/position/test_dlc_model.py +++ b/tests/position/test_dlc_model.py @@ -1,3 +1,6 @@ +import pytest + + def test_model_params_default(sgp): assert sgp.v1.DLCModelParams.get_default() == { "dlc_model_params_name": "default", @@ -8,3 +11,8 @@ def test_model_params_default(sgp): "model_prefix": "", }, } + + +def test_model_input_assert(sgp): + with pytest.raises(AssertionError): + sgp.v1.DLCModelInput().insert1({"config_path": "/fake/path/"}) diff --git a/tests/position/test_dlc_orient.py b/tests/position/test_dlc_orient.py index b36a40903..826df4cf9 100644 --- a/tests/position/test_dlc_orient.py +++ b/tests/position/test_dlc_orient.py @@ -1,9 +1,45 @@ +import numpy as np import pandas as pd import pytest +from .conftest import generate_led_df + + +def test_insert_params(sgp): + params_name = "test_params" + params_key = {"dlc_orientation_params_name": params_name} + params_tbl = sgp.v1.DLCOrientationParams() + params_tbl.insert_params( + params_name=params_name, params={}, skip_duplicates=True + ) + assert params_tbl & params_key, "Failed to insert params" + + defaults = params_tbl.get_default() + assert ( + defaults.get("params", {}).get("bodypart1") == "greenLED" + ), "Failed to insert default params" + -@pytest.mark.skip(reason="Needs labeled data") def test_orient_fetch1_dataframe(sgp, orient_key, populate_orient): + """Fetches dataframe, but example data has one led, no orientation""" fetched_df = (sgp.v1.DLCOrientation & orient_key).fetch1_dataframe() assert isinstance(fetched_df, pd.DataFrame) - raise NotImplementedError + + +@pytest.mark.parametrize( + "key, points, exp_sum", + [ + ("none", ["none"], 0.0), + ("red_green_orientation", ["bodypart1", "bodypart2"], -2.356), + ("red_led_bisector", ["led1", "led2", "led3"], -1.571), + ], +) +def test_orient_calcs(sgp, key, points, exp_sum): + func = sgp.v1.position_dlc_orient._key_to_func_dict[key] + + df = generate_led_df(points, inc_vals=True) + df_sum = np.nansum(func(df, **{p: p for p in points})) + + assert np.isclose( + df_sum, exp_sum, atol=0.001 + ), f"Failed to calculate orient via {key}" diff --git a/tests/position/test_dlc_pos.py b/tests/position/test_dlc_pos.py index a58b992dd..df878c90c 100644 --- a/tests/position/test_dlc_pos.py +++ b/tests/position/test_dlc_pos.py @@ -1,4 +1,3 @@ -import pandas as pd import pytest from numpy import isclose as np_isclose diff --git a/tests/position/test_pos_merge.py b/tests/position/test_pos_merge.py index 9f957d06a..af6b17e0f 100644 --- a/tests/position/test_pos_merge.py +++ b/tests/position/test_pos_merge.py @@ -1,8 +1,26 @@ -import pandas as pd import pytest +from numpy import isclose as np_isclose -def test_pos_merge(sgp, pos_merge, populate_dlc, dlc_key): - fetched_df = (sgp.v1.PositionOutput.DLCPosV1() & dlc_key).fetch1_dataframe() - assert isinstance(fetched_df, pd.DataFrame) - 
raise NotImplementedError +@pytest.fixture(scope="session") +def merge_df(sgp, pos_merge, dlc_key, populate_dlc): + merge_key = (pos_merge.DLCPosV1 & dlc_key).fetch1("KEY") + yield (pos_merge & merge_key).fetch1_dataframe() + + +@pytest.mark.parametrize( + "column, exp_sum", + [ # NOTE: same as test_centroid_fetch1_dataframe + ("video_frame_ind", 36312), + ("position_x", 17987), + ("position_y", 2983), + ("velocity_x", -1.489), + ("velocity_y", 4.160), + ("speed", 12957), + ], +) +def test_merge_dlc_fetch1_dataframe(merge_df, column, exp_sum): + tolerance = abs(merge_df[column].iloc[0] * 0.1) + assert np_isclose( + merge_df[column].sum(), exp_sum, atol=tolerance + ), f"Sum of {column} in Merge.DLCPosV1 dataframe is not as expected" diff --git a/tests/utils/test_mixin.py b/tests/utils/test_mixin.py index 010abf03c..5b6beb4d0 100644 --- a/tests/utils/test_mixin.py +++ b/tests/utils/test_mixin.py @@ -41,15 +41,19 @@ def test_merge_detect(Nwbfile, pos_merge_tables): ), "Merges not detected by mixin." -def test_merge_chain_join(Nwbfile, pos_merge_tables, lin_v1, lfp_merge_key): - """Test that the mixin can join merge chains.""" - _ = lin_v1, lfp_merge_key # merge tables populated +def test_merge_chain_join( + Nwbfile, pos_merge_tables, lin_v1, lfp_merge_key, populate_dlc +): + """Test that the mixin can join merge chains. + + NOTE: This will change if more data is added to merge tables.""" + _ = lin_v1, lfp_merge_key, populate_dlc # merge tables populated all_chains = [ chains.cascade(True, direction="down") for chains in Nwbfile._merge_chains.values() ] - end_len = [len(chain[0]) for chain in all_chains if chain] + end_len = [len(chain) for chain in all_chains] assert sum(end_len) == 4, "Merge chains not joined correctly."
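Both test_centroid_fetch1_dataframe and test_merge_dlc_fetch1_dataframe use the same check: compare a column's sum against an expected value with an absolute tolerance scaled to the column's first entry. A standalone sketch of that assertion, mirroring the tolerance used in test_merge_dlc_fetch1_dataframe; the helper name is illustrative:

    import numpy as np
    import pandas as pd

    def assert_column_sum(df: pd.DataFrame, column: str, exp_sum: float) -> None:
        # Tolerance is 10% of the magnitude of the column's first entry.
        tolerance = abs(df[column].iloc[0] * 0.1)
        assert np.isclose(
            df[column].sum(), exp_sum, atol=tolerance
        ), f"Sum of {column} is not as expected"

Scaling the tolerance to the data keeps the same parametrized expectations usable for columns whose magnitudes differ by orders of magnitude (frame indices vs. velocities).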