From 1ec054608d0a11678572206e92eae89d9ddeeb0e Mon Sep 17 00:00:00 2001 From: Ahmad Amine Date: Tue, 8 Oct 2024 02:44:38 -0400 Subject: [PATCH] Add track creation from path --- f1tenth_gym/envs/track/track.py | 60 +++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/f1tenth_gym/envs/track/track.py b/f1tenth_gym/envs/track/track.py index c6a0d973..eed8bb16 100644 --- a/f1tenth_gym/envs/track/track.py +++ b/f1tenth_gym/envs/track/track.py @@ -153,6 +153,66 @@ def from_track_name(track: str): print(ex) raise FileNotFoundError(f"It could not load track {track}") from ex + @staticmethod + def from_track_path(path: pathlib.Path): + """ + Load track from track path. + + Parameters + ---------- + path : pathlib.Path + path to the track yaml file + + Returns + ------- + Track + track object + + Raises + ------ + FileNotFoundError + if the track cannot be loaded + """ + try: + if type(path) is str: + path = pathlib.Path(path) + + track_spec = Track.load_spec( + track=path.stem, filespec=path + ) + + # load occupancy grid + # Image path is from path + image name from track_spec + image_path = path.parent / track_spec.image + image = Image.open(image_path).transpose(Transpose.FLIP_TOP_BOTTOM) + occupancy_map = np.array(image).astype(np.float32) + occupancy_map[occupancy_map <= 128] = 0.0 + occupancy_map[occupancy_map > 128] = 255.0 + + # if exists, load centerline + if (path / f"{path.stem}_centerline.csv").exists(): + centerline = Raceline.from_centerline_file(path / f"{path.stem}_centerline.csv") + else: + centerline = None + + # if exists, load raceline + if (path / f"{path.stem}_raceline.csv").exists(): + raceline = Raceline.from_raceline_file(path / f"{path.stem}_raceline.csv") + else: + raceline = centerline + + return Track( + spec=track_spec, + filepath=str(path.absolute()), + ext=image_path.suffix, + occupancy_map=occupancy_map, + centerline=centerline, + raceline=raceline, + ) + except Exception as ex: + print(ex) + raise FileNotFoundError(f"It could not load track {path}") from ex + @staticmethod def from_refline(x: np.ndarray, y: np.ndarray, velx: np.ndarray): """