Commit
Support for custom observation function, minor changes to my previous implementation (#126)

* Minor style changes to my previous implementation (to match the author's style)

* Support for custom observation function, preparation for drq_norm

* Working support for registerable observation functions, minor implementation changes

* Moved checking the reward function to TrafficSignal constructor
firemankoxd authored Dec 7, 2022
1 parent 4cfc04e commit d95389d
Showing 3 changed files with 80 additions and 57 deletions.
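For orientation, the observation_fn argument introduced by this commit accepts either a registered name (e.g. 'default') or a callable that is given the TrafficSignal instance, mirroring how reward functions are handled. Below is a minimal sketch of passing a callable directly; the file paths and the custom function are hypothetical, and the environment is assumed to take its usual net_file/route_file arguments. Note that the observation_space definition is not part of this diff, so a custom observation of a different length may also need a matching space.

import numpy as np
from sumo_rl import SumoEnvironment, TrafficSignal

def my_observation(ts: TrafficSignal):
    # A custom observation function receives the TrafficSignal instance
    # and can reuse its helpers, as the default implementation does.
    density = ts.get_lanes_density()
    queue = ts.get_lanes_queue()
    return np.array(density + queue, dtype=np.float32)

env = SumoEnvironment(
    net_file='my_net.net.xml',       # hypothetical path
    route_file='my_routes.rou.xml',  # hypothetical path
    observation_fn=my_observation,   # or a registered name such as 'default'
)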
2 changes: 1 addition & 1 deletion sumo_rl/__init__.py
@@ -1,3 +1,3 @@
-from sumo_rl.environment.env import SumoEnvironment
+from sumo_rl.environment.env import SumoEnvironment, TrafficSignal
 from sumo_rl.environment.env import env, parallel_env
 from sumo_rl.environment.resco_envs import grid4x4, arterial4x4, ingolstadt1, ingolstadt7, ingolstadt21, cologne1, cologne3, cologne8
67 changes: 30 additions & 37 deletions sumo_rl/environment/env.py
@@ -50,6 +50,7 @@ class SumoEnvironment(gym.Env):
     :param max_green: (int) Max green time in a phase
     :single_agent: (bool) If true, it behaves like a regular gym.Env. Else, it behaves like a MultiagentEnv (https://github.com/ray-project/ray/blob/master/python/ray/rllib/env/multi_agent_env.py)
     :reward_fn: (str/function/dict) String with the name of the reward function used by the agents, a reward function, or dictionary with reward functions assigned to individual traffic lights by their keys
+    :observation_fn: (str/function) String with the name of the observation function or a callable observation function itself
     :add_system_info: (bool) If true, it computes system metrics (total queue, total waiting time, average speed) in the info dictionary
     :add_per_agent_info: (bool) If true, it computes per-agent (per-traffic signal) metrics (average accumulated waiting time, average queue) in the info dictionary
     :sumo_seed: (int/string) Random seed for sumo. If 'random' it uses a randomly chosen seed.
@@ -82,6 +83,7 @@ def __init__(
         max_green: int = 50,
         single_agent: bool = False,
         reward_fn: Union[str,Callable,dict] = 'diff-waiting-time',
+        observation_fn: Union[str,Callable] = 'default',
         add_system_info: bool = True,
         add_per_agent_info: bool = True,
         sumo_seed: Union[str,int] = 'random',
@@ -133,33 +135,28 @@ def __init__(
         traci.start([sumolib.checkBinary('sumo'), '-n', self._net], label='init_connection'+self.label)
         conn = traci.getConnection('init_connection'+self.label)
         self.ts_ids = list(conn.trafficlight.getIDList())
+        self.observation_fn = observation_fn
 
         if isinstance(self.reward_fn, dict):
-            self.traffic_signals = dict()
-            for key, reward_fn_value in self.reward_fn.items():
-                self.traffic_signals[key] = TrafficSignal(
-                    self,
-                    key,
-                    self.delta_time,
-                    self.yellow_time,
-                    self.min_green,
-                    self.max_green,
-                    self.begin_time,
-                    reward_fn_value,
-                    conn
-                )
+            self.traffic_signals = {ts: TrafficSignal(self,
+                                                      ts,
+                                                      self.delta_time,
+                                                      self.yellow_time,
+                                                      self.min_green,
+                                                      self.max_green,
+                                                      self.begin_time,
+                                                      self.reward_fn[ts],
+                                                      conn) for ts in self.reward_fn.keys()}
         else:
-            self.traffic_signals = {
-                ts: TrafficSignal(self,
-                                  ts,
-                                  self.delta_time,
-                                  self.yellow_time,
-                                  self.min_green,
-                                  self.max_green,
-                                  self.begin_time,
-                                  self.reward_fn,
-                                  conn) for ts in self.ts_ids
-            }
+            self.traffic_signals = {ts: TrafficSignal(self,
+                                                      ts,
+                                                      self.delta_time,
+                                                      self.yellow_time,
+                                                      self.min_green,
+                                                      self.max_green,
+                                                      self.begin_time,
+                                                      self.reward_fn,
+                                                      conn) for ts in self.ts_ids}
 
         conn.close()
 
@@ -223,19 +220,15 @@ def reset(self, seed: Optional[int] = None, **kwargs):
         self._start_simulation()
 
         if isinstance(self.reward_fn, dict):
-            self.traffic_signals = dict()
-            for key, reward_fn_value in self.reward_fn.items():
-                self.traffic_signals[key] = TrafficSignal(
-                    self,
-                    key,
-                    self.delta_time,
-                    self.yellow_time,
-                    self.min_green,
-                    self.max_green,
-                    self.begin_time,
-                    reward_fn_value,
-                    self.sumo
-                )
+            self.traffic_signals = {ts: TrafficSignal(self,
+                                                      ts,
+                                                      self.delta_time,
+                                                      self.yellow_time,
+                                                      self.min_green,
+                                                      self.max_green,
+                                                      self.begin_time,
+                                                      self.reward_fn[ts],
+                                                      self.sumo) for ts in self.reward_fn.keys()}
         else:
             self.traffic_signals = {ts: TrafficSignal(self,
                                                       ts,
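As the reworked __init__ and reset() blocks above show, reward_fn may also be a dict keyed by traffic-signal id, in which case a TrafficSignal is created only for the ids listed as keys, each with its own reward function (a registered name or a callable). A sketch with made-up signal ids and file paths, not taken from this commit:

from sumo_rl import SumoEnvironment

env = SumoEnvironment(
    net_file='my_net.net.xml',       # hypothetical path
    route_file='my_routes.rou.xml',  # hypothetical path
    reward_fn={
        'tl_1': 'diff-waiting-time',                    # resolved via TrafficSignal.reward_fns
        'tl_2': lambda ts: -sum(ts.get_lanes_queue()),  # callable receiving the TrafficSignal
    },
)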
68 changes: 49 additions & 19 deletions sumo_rl/environment/traffic_signal.py
@@ -54,6 +54,20 @@ def __init__(self,
         self.reward_fn = reward_fn
         self.sumo = sumo
 
+        if type(self.reward_fn) is str:
+            if self.reward_fn in TrafficSignal.reward_fns.keys():
+                self.reward_fn = TrafficSignal.reward_fns[self.reward_fn]
+            else:
+                raise NotImplementedError(f'Reward function {self.reward_fn} not implemented')
+
+        if isinstance(self.env.observation_fn, Callable):
+            self.observation_fn = self.env.observation_fn
+        else:
+            if self.env.observation_fn in TrafficSignal.observation_fns.keys():
+                self.observation_fn = TrafficSignal.observation_fns[self.env.observation_fn]
+            else:
+                raise NotImplementedError(f'Observation function {self.env.observation_fn} not implemented')
+
         self.build_phases()
 
         self.lanes = list(dict.fromkeys(self.sumo.trafficlight.getControlledLanes(self.id))) # Remove duplicates and keep order
@@ -134,27 +148,10 @@ def set_next_phase(self, new_phase):
         self.time_since_last_phase_change = 0
 
     def compute_observation(self):
-        phase_id = [1 if self.green_phase == i else 0 for i in range(self.num_green_phases)] # one-hot encoding
-        min_green = [0 if self.time_since_last_phase_change < self.min_green + self.yellow_time else 1]
-        density = self.get_lanes_density()
-        queue = self.get_lanes_queue()
-        observation = np.array(phase_id + min_green + density + queue, dtype=np.float32)
-        return observation
+        return self.observation_fn(self)
 
     def compute_reward(self):
-        if type(self.reward_fn) is str:
-            if self.reward_fn == 'diff-waiting-time':
-                self.last_reward = self._diff_waiting_time_reward()
-            elif self.reward_fn == 'average-speed':
-                self.last_reward = self._average_speed_reward()
-            elif self.reward_fn == 'queue':
-                self.last_reward = self._queue_reward()
-            elif self.reward_fn == 'pressure':
-                self.last_reward = self._pressure_reward()
-            else:
-                raise NotImplementedError(f'Reward function {self.reward_fn} not implemented')
-        else:
-            self.last_reward = self.reward_fn(self)
+        self.last_reward = self.reward_fn(self)
         return self.last_reward
 
     def _pressure_reward(self):
@@ -172,6 +169,14 @@ def _diff_waiting_time_reward(self):
         self.last_measure = ts_wait
         return reward
 
+    def _observation_fn_default(self):
+        phase_id = [1 if self.green_phase == i else 0 for i in range(self.num_green_phases)] # one-hot encoding
+        min_green = [0 if self.time_since_last_phase_change < self.min_green + self.yellow_time else 1]
+        density = self.get_lanes_density()
+        queue = self.get_lanes_queue()
+        observation = np.array(phase_id + min_green + density + queue, dtype=np.float32)
+        return observation
+
     def get_accumulated_waiting_time_per_lane(self):
         wait_time_per_lane = []
         for lane in self.lanes:
@@ -220,3 +225,28 @@ def _get_veh_list(self):
         for lane in self.lanes:
             veh_list += self.sumo.lane.getLastStepVehicleIDs(lane)
         return veh_list
+
+    @classmethod
+    def register_reward_fn(cls, fn):
+        if fn.__name__ in cls.reward_fns.keys():
+            raise KeyError(f'Reward function {fn.__name__} already exists')
+
+        cls.reward_fns[fn.__name__] = fn
+
+    @classmethod
+    def register_observation_fn(cls, fn):
+        if fn.__name__ in cls.observation_fns.keys():
+            raise KeyError(f'Observation function {fn.__name__} already exists')
+
+        cls.observation_fns[fn.__name__] = fn
+
+    reward_fns = {
+        'diff-waiting-time': _diff_waiting_time_reward,
+        'average-speed': _average_speed_reward,
+        'queue': _queue_reward,
+        'pressure': _pressure_reward
+    }
+
+    observation_fns = {
+        'default': _observation_fn_default
+    }
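The class-level registries added above allow observation (and reward) functions to be registered under a string name before the environment is built; register_observation_fn stores the function under its __name__. A sketch of that workflow, where my_queue_obs and the file paths are hypothetical:

import numpy as np
from sumo_rl import SumoEnvironment, TrafficSignal

def my_queue_obs(ts):
    # Same contract as _observation_fn_default: takes the TrafficSignal,
    # returns a numpy array.
    return np.array(ts.get_lanes_queue(), dtype=np.float32)

TrafficSignal.register_observation_fn(my_queue_obs)  # stored under the key 'my_queue_obs'

env = SumoEnvironment(
    net_file='my_net.net.xml',       # hypothetical path
    route_file='my_routes.rou.xml',  # hypothetical path
    observation_fn='my_queue_obs',   # looked up in TrafficSignal.observation_fns
)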
