Skip to content

Commit

Permalink
Fix no dataset split error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
VlSomers committed Feb 12, 2024
1 parent 278aedb commit 976c435
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 18 deletions.
17 changes: 11 additions & 6 deletions tracklab/datastruct/tracking_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@
import pandas as pd


class SetsDict(dict):
def __getitem__(self, key):
if key not in self:
raise KeyError(f"Trying to access a '{key}' split of the dataset that is not available. "
f"Available splits are {list(self.keys())}. "
f"Make sur this split name is correct or is available in the dataset folder.")
return super().__getitem__(key)


@dataclass
class TrackingSet:
video_metadatas: pd.DataFrame
Expand All @@ -25,12 +34,8 @@ def __init__(
**kwargs
):
self.dataset_path = Path(dataset_path)
self.sets = sets
self.train_set = None
self.val_set = None
self.test_set = None

sub_sampled_sets = {}
self.sets = SetsDict(sets)
sub_sampled_sets = SetsDict()
for set_name, split in self.sets.items():
vid_list = vids_dict[set_name] if vids_dict is not None and set_name in vids_dict else None
sub_sampled_sets[set_name] = self._subsample(split, nvid, nframes, vid_list)
Expand Down
18 changes: 6 additions & 12 deletions tracklab/wrappers/datasets/soccernet/soccernet_game_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,12 @@ def __init__(self,
download_dataset(self.dataset_path)
assert self.dataset_path.exists(), f"'{self.dataset_path}' directory does not exist. Please check the path or download the dataset following the instructions here: https://github.com/SoccerNet/sn-gamestate"

train_set = load_set(self.dataset_path / "train", nvid, vids_dict.get("train", [])) if os.path.exists(self.dataset_path / "train") else None
val_set = load_set(self.dataset_path / "valid", nvid, vids_dict.get("valid", [])) if os.path.exists(self.dataset_path / "valid") else None
test_set = load_set(self.dataset_path / "test", nvid, vids_dict.get("test", [])) if (self.dataset_path / "test").exists() else None
challenge = load_set(self.dataset_path / "challenge", nvid, vids_dict.get("challenge", [])) if os.path.exists(self.dataset_path / "challenge") else None

sets = {
"train": train_set,
"valid": val_set,
"test": test_set,
"challenge": challenge
}
sets = {}
for split in ["train", "valid", "test", "challenge"]:
if os.path.exists(self.dataset_path / split):
sets[split] = load_set(self.dataset_path / split, nvid, vids_dict.get(split, []))
else:
log.warning(f"Warning: The '{split}' set does not exist in the SoccerNetGS dataset at '{self.dataset_path}'")

# We pass 'nvid=-1', 'vids_dict=None' because video subsampling is already done in the load_set function
super().__init__(dataset_path, sets, nvid=-1, vids_dict=None, *args, **kwargs)
Expand Down Expand Up @@ -202,7 +197,6 @@ def load_set(dataset_path, nvid=-1, vids_filter_set=None):
annotations_pitch_camera_list = []
detections_list = []
categories_list = []
image_gt_challenge = []
split = os.path.basename(dataset_path) # Get the split name from the dataset path
video_list = os.listdir(dataset_path)
video_list.sort()
Expand Down

0 comments on commit 976c435

Please sign in to comment.