optimize the reward structure; sample output now:
===== timestep 220 =====
route completion:
	route_completion/goals/default: 0.92
	route_completion/goals/go_straight: 0.50
	route_completion/goals/left_turn: 0.46
	route_completion/goals/right_turn: 0.92
	route_completion/goals/u_turn: 0.53

reward:
	reward/default_reward: 1.27
	reward/goal_agnostic_reward: 0.05
	reward/goals/default: 1.27
	reward/goals/go_straight: 0.14
	reward/goals/left_turn: 0.06
	reward/goals/right_turn: 1.27
	reward/goals/u_turn: 0.01
=======================
pengzhenghao committed May 12, 2024
1 parent 686957d commit d40dfc2
Showing 1 changed file with 64 additions and 41 deletions: metadrive/envs/multigoal_intersection.py
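Reading the printout in the commit message against the new code below: each per-goal entry is the shared goal-agnostic term plus a goal-specific term, and the default route apparently coincides with the right turn here (identical reward and route completion). A minimal sketch of that decomposition, with the goal-specific values inferred from the printout rather than taken from any log:

goal_agnostic_reward = 0.05  # printed as reward/goal_agnostic_reward
# Inferred goal-specific terms: printed total minus the shared term.
goal_specific = {"right_turn": 1.22, "go_straight": 0.09, "left_turn": 0.01, "u_turn": -0.04}
for goal, term in goal_specific.items():
    print(f"reward/goals/{goal}: {goal_agnostic_reward + term:.2f}")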
@@ -288,6 +288,38 @@ def _get_reset_return(self, reset_info):
 
         return o, i
 
+    def _reward_per_navigation(self, vehicle, navi, goal_name):
+        """Compute the reward for the given goal. goal_name='default' means we use the vehicle's own navigation."""
+        reward = 0.0
+
+        # Get goal-dependent information
+        if navi.current_lane in navi.current_ref_lanes:
+            current_lane = navi.current_lane
+            positive_road = 1
+        else:
+            current_lane = navi.current_ref_lanes[0]
+            current_road = navi.current_road
+            positive_road = 1 if not current_road.is_negative_road() else -1
+        long_last, _ = current_lane.local_coordinates(vehicle.last_position)
+        long_now, lateral_now = current_lane.local_coordinates(vehicle.position)
+
+        # Reward for moving forward in current lane
+        reward += self.config["driving_reward"] * (long_now - long_last) * positive_road
+
+        # Reward for speed; the sign is determined by whether the vehicle is in the correct
+        # lane (rather than driving in the wrong direction).
+        reward += self.config["speed_reward"] * (vehicle.speed_km_h / vehicle.max_speed_km_h) * positive_road
+        # Terminal events override the dense reward.
+        if self._is_arrive_destination(vehicle, goal_name):
+            reward = +self.config["success_reward"]
+        elif self._is_out_of_road(vehicle):
+            reward = -self.config["out_of_road_penalty"]
+        elif vehicle.crash_vehicle:
+            reward = -self.config["crash_vehicle_penalty"]
+        elif vehicle.crash_object:
+            reward = -self.config["crash_object_penalty"]
+        return reward, navi.route_completion
+
+
     def reward_function(self, vehicle_id: str):
         """
         Compared to the original reward_function, we add goal-dependent reward to info dict.
@@ -315,37 +347,19 @@ def reward_function(self, vehicle_id: str):
         for goal_name in self.engine.goal_manager.goals.keys():
             navi = self.engine.goal_manager.get_navigation(goal_name)
             prefix = goal_name
-            reward = 0.0
-
-            # Get goal-dependent information
-            if navi.current_lane in navi.current_ref_lanes:
-                current_lane = navi.current_lane
-                positive_road = 1
-            else:
-                current_lane = navi.current_ref_lanes[0]
-                current_road = navi.current_road
-                positive_road = 1 if not current_road.is_negative_road() else -1
-            long_last, _ = current_lane.local_coordinates(vehicle.last_position)
-            long_now, lateral_now = current_lane.local_coordinates(vehicle.position)
-
-            # Reward for moving forward in current lane
-            reward += self.config["driving_reward"] * (long_now - long_last) * positive_road
-
-            # Reward for speed, sign determined by whether in the correct lanes (instead of driving in the wrong
-            # direction).
-            reward += self.config["speed_reward"] * (vehicle.speed_km_h / vehicle.max_speed_km_h) * positive_road
-            if self._is_arrive_destination(vehicle, goal_name):
-                reward = +self.config["success_reward"]
-            elif self._is_out_of_road(vehicle):
-                reward = -self.config["out_of_road_penalty"]
-            elif vehicle.crash_vehicle:
-                reward = -self.config["crash_vehicle_penalty"]
-            elif vehicle.crash_object:
-                reward = -self.config["crash_object_penalty"]
-            step_info[f"reward/goals/{prefix}"] = reward
-            step_info[f"route_completion/goals/{prefix}"] = navi.route_completion
-
-        return goal_agnostic_reward, step_info
+            reward, route_completion = self._reward_per_navigation(vehicle, navi, goal_name)
+            step_info[f"reward/goals/{prefix}"] = reward + goal_agnostic_reward
+            step_info[f"route_completion/goals/{prefix}"] = route_completion
+
+        default_reward, default_rc = self._reward_per_navigation(vehicle, vehicle.navigation, "default")
+        step_info["reward/goals/default"] = default_reward + goal_agnostic_reward
+        step_info["route_completion/goals/default"] = default_rc
+
+        default_reward = goal_agnostic_reward + default_reward
+        step_info["reward/goal_agnostic_reward"] = goal_agnostic_reward
+        step_info["reward/default_reward"] = default_reward
+
+        return default_reward, step_info
 
     def _is_arrive_destination(self, vehicle, goal_name=None):
         """
@@ -364,7 +378,11 @@ def _is_arrive_destination(self, vehicle, goal_name=None):
                 ret = ret or self._is_arrive_destination(vehicle, name)
             return ret
 
-        navi = self.engine.goal_manager.get_navigation(goal_name)
+        if goal_name == "default":
+            navi = vehicle.navigation
+        else:
+            navi = self.engine.goal_manager.get_navigation(goal_name)
+
         long, lat = navi.final_lane.local_coordinates(vehicle.position)
         flag = (navi.final_lane.length - 5 < long < navi.final_lane.length + 5) and (
             navi.get_current_lane_width() / 2 >= lat >=
@@ -387,8 +405,8 @@ def done_function(self, vehicle_id: str):
 
 if __name__ == "__main__":
     config = dict(
-        use_render=False,
-        manual_control=False,
+        use_render=True,
+        manual_control=True,
         vehicle_config=dict(show_lidar=False, show_navi_mark=True, show_line_to_navi_mark=True),
         accident_prob=1.0,
         decision_repeat=5,
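use_render and manual_control are flipped to True above for interactive debugging; for a headless run one would presumably flip them back before constructing the env, e.g.:

config.update(use_render=False, manual_control=False)  # hypothetical headless override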
@@ -425,13 +443,18 @@ def done_function(self, vehicle_id: str):
             if k.startswith("reward/goals"):
                 episode_rewards[k] += v
 
-        # if s % 20 == 0:
-        #     info = {k: info[k] for k in sorted(info.keys())}
-        #     print('\n===== timestep {} ====='.format(s))
-        #     for k, v in info.items():
-        #         if k.startswith("obs/goals/"):
-        #             print(f"{k}: {v:.2f}")
-        #     print('=======================')
+        if s % 20 == 0:
+            print('\n===== timestep {} ====='.format(s))
+            print('route completion:')
+            for k in sorted(info.keys()):
+                if k.startswith("route_completion/goals/"):
+                    print(f"\t{k}: {info[k]:.2f}")
+
+            print('\nreward:')
+            for k in sorted(info.keys()):
+                if k.startswith("reward/"):
+                    print(f"\t{k}: {info[k]:.2f}")
+            print('=======================')
 
         if done:
             print('\n===== timestep {} ====='.format(s))
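To reproduce a printout like the one in the commit message, the demo loop under __main__ can be run directly; a minimal sketch, assuming metadrive is installed:

import runpy

# Executes the demo loop guarded by `if __name__ == "__main__":` in this module.
runpy.run_module("metadrive.envs.multigoal_intersection", run_name="__main__")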
