Skip to content

Commit

Permalink
allow change radius of the intersection; set accident_prob=1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
pengzhenghao committed Apr 17, 2024
1 parent d887fbb commit 6a9a7b7
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 18 deletions.
1 change: 1 addition & 0 deletions metadrive/component/pgblock/intersection.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,4 @@ def get_intermediate_spawn_lanes(self):
class InterSectionWithUTurn(InterSection):
ID = "U"
_enable_u_turn_flag = True
SOCKET_NUM = 4
49 changes: 31 additions & 18 deletions metadrive/envs/multigoal_intersection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,32 @@

from metadrive import MetaDriveEnv
from metadrive.component.navigation_module.node_network_navigation import NodeNetworkNavigation
from metadrive.component.pg_space import ParameterSpace, Parameter, ConstantSpace, DiscreteSpace
from metadrive.component.pgblock.first_block import FirstPGBlock
from metadrive.component.pgblock.intersection import InterSectionWithUTurn
from metadrive.component.road_network import Road
from metadrive.constants import DEFAULT_AGENT
from metadrive.engine.logger import get_logger
from metadrive.manager.base_manager import BaseManager

from metadrive.envs.varying_dynamics_env import VaryingDynamicsAgentManager, VaryingDynamicsConfig
from metadrive.manager.base_manager import BaseManager

logger = get_logger()


class CustomizedIntersection(InterSectionWithUTurn):
PARAMETER_SPACE = ParameterSpace(
{

# changed from 10 to 8:
Parameter.radius: ConstantSpace(8),

# unchanged:
Parameter.change_lane_num: DiscreteSpace(min=0, max=1),
Parameter.decrease_increase: DiscreteSpace(min=0, max=1)
}
)


class MultiGoalIntersectionNavigationManager(BaseManager):
"""
This manager is responsible for managing multiple navigation modules, each of which is responsible for guiding the
Expand All @@ -28,16 +42,16 @@ class MultiGoalIntersectionNavigationManager(BaseManager):
GOALS = {
"u_turn": (-Road(FirstPGBlock.NODE_2, FirstPGBlock.NODE_3)).end_node,
"right_turn": Road(
InterSectionWithUTurn.node(block_idx=1, part_idx=0, road_idx=0),
InterSectionWithUTurn.node(block_idx=1, part_idx=0, road_idx=1)
CustomizedIntersection.node(block_idx=1, part_idx=0, road_idx=0),
CustomizedIntersection.node(block_idx=1, part_idx=0, road_idx=1)
).end_node,
"go_straight": Road(
InterSectionWithUTurn.node(block_idx=1, part_idx=1, road_idx=0),
InterSectionWithUTurn.node(block_idx=1, part_idx=1, road_idx=1)
CustomizedIntersection.node(block_idx=1, part_idx=1, road_idx=0),
CustomizedIntersection.node(block_idx=1, part_idx=1, road_idx=1)
).end_node,
"left_turn": Road(
InterSectionWithUTurn.node(block_idx=1, part_idx=2, road_idx=0),
InterSectionWithUTurn.node(block_idx=1, part_idx=2, road_idx=1)
CustomizedIntersection.node(block_idx=1, part_idx=2, road_idx=0),
CustomizedIntersection.node(block_idx=1, part_idx=2, road_idx=1)
).end_node,
}

Expand Down Expand Up @@ -91,7 +105,6 @@ class MultiGoalIntersectionEnv(MetaDriveEnv):
This environment is an intersection with multiple goals. We provide the reward function, observation, termination
conditions for each goal in the info dict returned by env.reset and env.step, with prefix "goals/{goal_name}/".
"""

@classmethod
def default_config(cls):
config = MetaDriveEnv.default_config()
Expand All @@ -101,14 +114,20 @@ def default_config(cls):

# Set the map to an Intersection
"start_seed": 0,
"map": "U",

# Disable the shortcut config for map.
"map": None,
"map_config": dict(
type="block_sequence", config=[
CustomizedIntersection,
], lane_num=2, lane_width=3.5
),

# Even though the map will not change, the traffic flow will change.
"num_scenarios": 1000,

# Remove all traffic vehicles for now.
"traffic_density": 0.2,

"vehicle_config": {

# Remove navigation arrows in the window as we are in multi-goal environment.
Expand Down Expand Up @@ -256,14 +275,8 @@ def done_function(self, vehicle_id: str):
use_render=True,
manual_control=True,
vehicle_config=dict(show_lidar=False, show_navi_mark=True, show_line_to_navi_mark=True),
map_config=dict(
type="block_sequence",
config="U",
lane_num=2,
lane_width=3.5
),
accident_prob=1.0,
decision_repeat=1,
decision_repeat=5,
)
env = MultiGoalIntersectionEnv(config)
episode_rewards = defaultdict(float)
Expand Down

0 comments on commit 6a9a7b7

Please sign in to comment.