diff --git a/metadrive/component/pgblock/intersection.py b/metadrive/component/pgblock/intersection.py index 4c5a762f3..5e91407eb 100644 --- a/metadrive/component/pgblock/intersection.py +++ b/metadrive/component/pgblock/intersection.py @@ -268,3 +268,4 @@ def get_intermediate_spawn_lanes(self): class InterSectionWithUTurn(InterSection): ID = "U" _enable_u_turn_flag = True + SOCKET_NUM = 4 diff --git a/metadrive/envs/multigoal_intersection.py b/metadrive/envs/multigoal_intersection.py index f44cf091a..6c51a1ec7 100644 --- a/metadrive/envs/multigoal_intersection.py +++ b/metadrive/envs/multigoal_intersection.py @@ -8,18 +8,32 @@ from metadrive import MetaDriveEnv from metadrive.component.navigation_module.node_network_navigation import NodeNetworkNavigation +from metadrive.component.pg_space import ParameterSpace, Parameter, ConstantSpace, DiscreteSpace from metadrive.component.pgblock.first_block import FirstPGBlock from metadrive.component.pgblock.intersection import InterSectionWithUTurn from metadrive.component.road_network import Road from metadrive.constants import DEFAULT_AGENT from metadrive.engine.logger import get_logger -from metadrive.manager.base_manager import BaseManager - from metadrive.envs.varying_dynamics_env import VaryingDynamicsAgentManager, VaryingDynamicsConfig +from metadrive.manager.base_manager import BaseManager logger = get_logger() +class CustomizedIntersection(InterSectionWithUTurn): + PARAMETER_SPACE = ParameterSpace( + { + + # changed from 10 to 8: + Parameter.radius: ConstantSpace(8), + + # unchanged: + Parameter.change_lane_num: DiscreteSpace(min=0, max=1), + Parameter.decrease_increase: DiscreteSpace(min=0, max=1) + } + ) + + class MultiGoalIntersectionNavigationManager(BaseManager): """ This manager is responsible for managing multiple navigation modules, each of which is responsible for guiding the @@ -28,16 +42,16 @@ class MultiGoalIntersectionNavigationManager(BaseManager): GOALS = { "u_turn": (-Road(FirstPGBlock.NODE_2, FirstPGBlock.NODE_3)).end_node, "right_turn": Road( - InterSectionWithUTurn.node(block_idx=1, part_idx=0, road_idx=0), - InterSectionWithUTurn.node(block_idx=1, part_idx=0, road_idx=1) + CustomizedIntersection.node(block_idx=1, part_idx=0, road_idx=0), + CustomizedIntersection.node(block_idx=1, part_idx=0, road_idx=1) ).end_node, "go_straight": Road( - InterSectionWithUTurn.node(block_idx=1, part_idx=1, road_idx=0), - InterSectionWithUTurn.node(block_idx=1, part_idx=1, road_idx=1) + CustomizedIntersection.node(block_idx=1, part_idx=1, road_idx=0), + CustomizedIntersection.node(block_idx=1, part_idx=1, road_idx=1) ).end_node, "left_turn": Road( - InterSectionWithUTurn.node(block_idx=1, part_idx=2, road_idx=0), - InterSectionWithUTurn.node(block_idx=1, part_idx=2, road_idx=1) + CustomizedIntersection.node(block_idx=1, part_idx=2, road_idx=0), + CustomizedIntersection.node(block_idx=1, part_idx=2, road_idx=1) ).end_node, } @@ -91,7 +105,6 @@ class MultiGoalIntersectionEnv(MetaDriveEnv): This environment is an intersection with multiple goals. We provide the reward function, observation, termination conditions for each goal in the info dict returned by env.reset and env.step, with prefix "goals/{goal_name}/". """ - @classmethod def default_config(cls): config = MetaDriveEnv.default_config() @@ -101,14 +114,20 @@ def default_config(cls): # Set the map to an Intersection "start_seed": 0, - "map": "U", + + # Disable the shortcut config for map. + "map": None, + "map_config": dict( + type="block_sequence", config=[ + CustomizedIntersection, + ], lane_num=2, lane_width=3.5 + ), # Even though the map will not change, the traffic flow will change. "num_scenarios": 1000, # Remove all traffic vehicles for now. "traffic_density": 0.2, - "vehicle_config": { # Remove navigation arrows in the window as we are in multi-goal environment. @@ -256,14 +275,8 @@ def done_function(self, vehicle_id: str): use_render=True, manual_control=True, vehicle_config=dict(show_lidar=False, show_navi_mark=True, show_line_to_navi_mark=True), - map_config=dict( - type="block_sequence", - config="U", - lane_num=2, - lane_width=3.5 - ), accident_prob=1.0, - decision_repeat=1, + decision_repeat=5, ) env = MultiGoalIntersectionEnv(config) episode_rewards = defaultdict(float)