Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move edt as part of track asset #117

Open
wants to merge 10 commits into
base: v1.0.0
Choose a base branch
from
43 changes: 43 additions & 0 deletions examples/create_edt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from f110_gym.envs.track import Track
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from f110_gym.envs.track import Track
from f1tenth_gym.envs.track import Track

from scipy.ndimage import distance_transform_edt as edt
import numpy as np

DEFAULT_MAP_NAMES = [
"Austin",
"BrandsHatch",
"Budapest",
"Catalunya",
"Hockenheim",
"IMS",
"Melbourne",
"MexicoCity",
"Montreal",
"Monza",
"MoscowRaceway",
"Nuerburgring",
"Oschersleben",
"Sakhir",
"SaoPaulo",
"Sepang",
"Shanghai",
"Silverstone",
"Sochi",
"Spa",
"Spielberg",
"YasMarina",
"Zandvoort",
]

for track_name in DEFAULT_MAP_NAMES:
print("Loading a map without edt, a warning should appear")
hzheng40 marked this conversation as resolved.
Show resolved Hide resolved
track = Track.from_track_name(track_name)
occupancy_map = track.occupancy_map
resolution = track.spec.resolution

dt = resolution * edt(occupancy_map)

# saving
np.save(track.filepath, dt)

print("Loading a map with edt, warning should no longer appear")
hzheng40 marked this conversation as resolved.
Show resolved Hide resolved
track_wedt = Track.from_track_name(track_name)
5 changes: 4 additions & 1 deletion gym/f110_gym/envs/laser_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,10 @@ def set_map(self, map: str | Track):
self.orig_c = np.cos(self.origin[2])

# get the distance transform
self.dt = get_dt(self.map_img, self.map_resolution)
if self.track.edt is not None:
self.dt = self.track.edt
else:
self.dt = get_dt(self.map_img, self.map_resolution)

return True

Expand Down
12 changes: 12 additions & 0 deletions gym/f110_gym/envs/track/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class Track:
filepath: str
ext: str
occupancy_map: np.ndarray
edt: np.ndarray
centerline: Raceline
raceline: Raceline

Expand All @@ -40,6 +41,7 @@ def __init__(
filepath: str,
ext: str,
occupancy_map: np.ndarray,
edt: Optional[np.ndarray] = None,
centerline: Optional[Raceline] = None,
raceline: Optional[Raceline] = None,
):
Expand All @@ -56,6 +58,8 @@ def __init__(
file extension of the track image file
occupancy_map : np.ndarray
occupancy grid map
edt : np.ndarray
distance transform of the map
centerline : Raceline, optional
centerline of the track, by default None
raceline : Raceline, optional
Expand All @@ -65,6 +69,7 @@ def __init__(
self.filepath = filepath
self.ext = ext
self.occupancy_map = occupancy_map
self.edt = edt
self.centerline = centerline
self.raceline = raceline

Expand Down Expand Up @@ -125,6 +130,12 @@ def from_track_name(track: str):
occupancy_map[occupancy_map <= 128] = 0.0
occupancy_map[occupancy_map > 128] = 255.0

# if exists, load edt
if (track_dir / f"{track}_map.npy").exists():
edt = np.load(track_dir / f"{track}_map.npy")
else:
edt = None
hzheng40 marked this conversation as resolved.
Show resolved Hide resolved

hzheng40 marked this conversation as resolved.
Show resolved Hide resolved
# if exists, load centerline
if (track_dir / f"{track}_centerline.csv").exists():
centerline = Raceline.from_centerline_file(
Expand All @@ -146,6 +157,7 @@ def from_track_name(track: str):
filepath=str((track_dir / map_filename.stem).absolute()),
ext=map_filename.suffix,
occupancy_map=occupancy_map,
edt=edt,
centerline=centerline,
raceline=raceline,
)
Expand Down
2 changes: 1 addition & 1 deletion gym/f110_gym/envs/track/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def find_track_dir(track_name: str) -> pathlib.Path:
FileNotFoundError
if no map directory matching the track name is found
"""
map_dir = pathlib.Path(__file__).parent.parent.parent.parent / "maps"
map_dir = pathlib.Path(__file__).parent.parent.parent.parent.parent / "maps"
hzheng40 marked this conversation as resolved.
Show resolved Hide resolved

if not (map_dir / track_name).exists():
print("Downloading Files for: " + track_name)
Expand Down
Loading