Skip to content

Commit

Permalink
format, ready to launch SB3 td3
Browse files Browse the repository at this point in the history
  • Loading branch information
pengzhenghao committed May 13, 2024
1 parent 07e758f commit a486d64
Showing 1 changed file with 22 additions and 20 deletions.
42 changes: 22 additions & 20 deletions metadrive/envs/multigoal_intersection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

"""
This file implement an intersection environment with multiple goals.
"""
Expand Down Expand Up @@ -68,7 +67,9 @@ def observe(self, vehicle):

def state_observe(self, vehicle):
# update out of road
info = np.zeros([EGO_STATE_DIM, ])
info = np.zeros([
EGO_STATE_DIM,
])

# The velocity of target vehicle
info[0] = clip((vehicle.speed_km_h + 1) / (vehicle.max_speed_km_h + 1), 0.0, 1.0)
Expand All @@ -91,22 +92,26 @@ def state_observe(self, vehicle):
return info

def side_detector_observe(self, vehicle):
return np.asarray(self.engine.get_sensor("side_detector").perceive(
vehicle,
num_lasers=vehicle.config["side_detector"]["num_lasers"],
distance=vehicle.config["side_detector"]["distance"],
physics_world=vehicle.engine.physics_world.static_world,
show=vehicle.config["show_side_detector"],
).cloud_points)
return np.asarray(
self.engine.get_sensor("side_detector").perceive(
vehicle,
num_lasers=vehicle.config["side_detector"]["num_lasers"],
distance=vehicle.config["side_detector"]["distance"],
physics_world=vehicle.engine.physics_world.static_world,
show=vehicle.config["show_side_detector"],
).cloud_points
)

def lane_line_detector_observe(self, vehicle):
return np.asarray(self.engine.get_sensor("lane_line_detector").perceive(
vehicle,
vehicle.engine.physics_world.static_world,
num_lasers=vehicle.config["lane_line_detector"]["num_lasers"],
distance=vehicle.config["lane_line_detector"]["distance"],
show=vehicle.config["show_lane_line_detector"],
).cloud_points)
return np.asarray(
self.engine.get_sensor("lane_line_detector").perceive(
vehicle,
vehicle.engine.physics_world.static_world,
num_lasers=vehicle.config["lane_line_detector"]["num_lasers"],
distance=vehicle.config["lane_line_detector"]["distance"],
show=vehicle.config["show_lane_line_detector"],
).cloud_points
)

def vehicle_detector_observe(self, vehicle):
cloud_points, detected_objects = self.engine.get_sensor("lidar").perceive(
Expand Down Expand Up @@ -230,7 +235,6 @@ def default_config(cls):
CustomizedIntersection,
], lane_num=1, lane_width=3.5
),

"agent_observation": CustomizedObservation,

# Even though the map will not change, the traffic flow will change.
Expand All @@ -245,7 +249,6 @@ def default_config(cls):

# Turn off vehicle's own navigation module.
"side_detector": dict(num_lasers=SIDE_DETECT, distance=50), # laser num, distance

"lidar": dict(num_lasers=VEHICLE_DETECT, distance=50),

# To avoid goal-dependent lane detection, we use Lidar to detect distance to nearby lane lines.
Expand Down Expand Up @@ -320,7 +323,6 @@ def _reward_per_navigation(self, vehicle, navi, goal_name):
reward = -self.config["crash_object_penalty"]
return reward, navi.route_completion


def reward_function(self, vehicle_id: str):
"""
Compared to the original reward_function, we add goal-dependent reward to info dict.
Expand Down Expand Up @@ -491,4 +493,4 @@ def done_function(self, vehicle_id: str):
env.reset()
s = 0
finally:
env.close()
env.close()

0 comments on commit a486d64

Please sign in to comment.