From 976c43571f259557c4b0ea6ac58ad39384d3004d Mon Sep 17 00:00:00 2001 From: Vladimir Date: Mon, 12 Feb 2024 11:24:00 +0100 Subject: [PATCH] Fix no dataset split error handling --- tracklab/datastruct/tracking_dataset.py | 17 +++++++++++------ .../datasets/soccernet/soccernet_game_state.py | 18 ++++++------------ 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/tracklab/datastruct/tracking_dataset.py b/tracklab/datastruct/tracking_dataset.py index e7f24589..217b1d69 100644 --- a/tracklab/datastruct/tracking_dataset.py +++ b/tracklab/datastruct/tracking_dataset.py @@ -5,6 +5,15 @@ import pandas as pd +class SetsDict(dict): + def __getitem__(self, key): + if key not in self: + raise KeyError(f"Trying to access a '{key}' split of the dataset that is not available. " + f"Available splits are {list(self.keys())}. " + f"Make sur this split name is correct or is available in the dataset folder.") + return super().__getitem__(key) + + @dataclass class TrackingSet: video_metadatas: pd.DataFrame @@ -25,12 +34,8 @@ def __init__( **kwargs ): self.dataset_path = Path(dataset_path) - self.sets = sets - self.train_set = None - self.val_set = None - self.test_set = None - - sub_sampled_sets = {} + self.sets = SetsDict(sets) + sub_sampled_sets = SetsDict() for set_name, split in self.sets.items(): vid_list = vids_dict[set_name] if vids_dict is not None and set_name in vids_dict else None sub_sampled_sets[set_name] = self._subsample(split, nvid, nframes, vid_list) diff --git a/tracklab/wrappers/datasets/soccernet/soccernet_game_state.py b/tracklab/wrappers/datasets/soccernet/soccernet_game_state.py index 813ed1cc..9d731683 100644 --- a/tracklab/wrappers/datasets/soccernet/soccernet_game_state.py +++ b/tracklab/wrappers/datasets/soccernet/soccernet_game_state.py @@ -28,17 +28,12 @@ def __init__(self, download_dataset(self.dataset_path) assert self.dataset_path.exists(), f"'{self.dataset_path}' directory does not exist. Please check the path or download the dataset following the instructions here: https://github.com/SoccerNet/sn-gamestate" - train_set = load_set(self.dataset_path / "train", nvid, vids_dict.get("train", [])) if os.path.exists(self.dataset_path / "train") else None - val_set = load_set(self.dataset_path / "valid", nvid, vids_dict.get("valid", [])) if os.path.exists(self.dataset_path / "valid") else None - test_set = load_set(self.dataset_path / "test", nvid, vids_dict.get("test", [])) if (self.dataset_path / "test").exists() else None - challenge = load_set(self.dataset_path / "challenge", nvid, vids_dict.get("challenge", [])) if os.path.exists(self.dataset_path / "challenge") else None - - sets = { - "train": train_set, - "valid": val_set, - "test": test_set, - "challenge": challenge - } + sets = {} + for split in ["train", "valid", "test", "challenge"]: + if os.path.exists(self.dataset_path / split): + sets[split] = load_set(self.dataset_path / split, nvid, vids_dict.get(split, [])) + else: + log.warning(f"Warning: The '{split}' set does not exist in the SoccerNetGS dataset at '{self.dataset_path}'") # We pass 'nvid=-1', 'vids_dict=None' because video subsampling is already done in the load_set function super().__init__(dataset_path, sets, nvid=-1, vids_dict=None, *args, **kwargs) @@ -202,7 +197,6 @@ def load_set(dataset_path, nvid=-1, vids_filter_set=None): annotations_pitch_camera_list = [] detections_list = [] categories_list = [] - image_gt_challenge = [] split = os.path.basename(dataset_path) # Get the split name from the dataset path video_list = os.listdir(dataset_path) video_list.sort()