[Refactor] Introduce new EGreedyModule and compatibility with action mask

Signed-off-by: Matteo Bettini <[email protected]>
matteobettini committed Sep 8, 2023
1 parent 0b3afdb commit 059934a
Showing 7 changed files with 67 additions and 29 deletions.
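
For context, a minimal sketch of the pattern this commit adopts: instead of wrapping the whole policy in EGreedyWrapper, an EGreedyModule is appended to the policy with TensorDictSequential, and both QValueModule and EGreedyModule receive an action_mask_key so that invalid actions are excluded. The network, keys, and specs below are hypothetical stand-ins (not BenchMARL's actual modules), and the keyword names follow the TorchRL API as used in the diff; exact signatures may differ across TorchRL versions.

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.data import OneHotDiscreteTensorSpec
from torchrl.modules import EGreedyModule, QValueModule

n_actions = 4

# Hypothetical Q-network writing per-action values under "action_value".
qnet = TensorDictModule(
    torch.nn.Linear(8, n_actions),
    in_keys=["observation"],
    out_keys=["action_value"],
)

# Greedy action selection; the optional mask removes invalid actions from the argmax.
value_module = QValueModule(
    action_space=None,  # inferred from the spec
    action_value_key="action_value",
    action_mask_key="action_mask",
    out_keys=["action", "action_value", "chosen_action_value"],
    spec=OneHotDiscreteTensorSpec(n_actions),
)
policy_for_loss = TensorDictSequential(qnet, value_module)

# Exploration is no longer a wrapper around the policy (EGreedyWrapper) but a
# separate module appended after it.
greedy = EGreedyModule(
    spec=OneHotDiscreteTensorSpec(n_actions),
    annealing_num_steps=1000,
    action_key="action",
    action_mask_key="action_mask",
)
policy_for_collection = TensorDictSequential(policy_for_loss, greedy)

td = TensorDict(
    {
        "observation": torch.randn(2, 8),
        "action_mask": torch.tensor(
            [[True, True, False, False], [True, False, True, True]]
        ),
    },
    batch_size=[2],
)
policy_for_collection(td)          # writes a one-hot "action" consistent with the mask
policy_for_collection[-1].step(1)  # anneal epsilon by one collected frame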
2 changes: 1 addition & 1 deletion benchmarl/algorithms/common.py
@@ -140,7 +140,7 @@ def get_policy_for_loss(self, group: str) -> TensorDictModule:
         )
         return self._policies_for_loss[group]

-    def get_policy_for_collection(self) -> TensorDictModule:
+    def get_policy_for_collection(self) -> TensorDictSequential:
         policies = []
         for group in self.group_map.keys():
             if group not in self._policies_for_collection.keys():
19 changes: 13 additions & 6 deletions benchmarl/algorithms/iql.py
@@ -13,7 +13,7 @@
 )
 from torchrl.data.replay_buffers.samplers import PrioritizedSampler
 from torchrl.data.replay_buffers.storages import LazyTensorStorage
-from torchrl.modules import EGreedyWrapper, QValueModule
+from torchrl.modules import EGreedyModule, QValueModule
 from torchrl.objectives import ClipPPOLoss, DQNLoss, LossModule, ValueEstimators
 from torchrl.objectives.utils import SoftUpdate, TargetNetUpdater

@@ -134,11 +134,13 @@ def _get_policy_for_loss(
             device=self.device,
         )
         if self.action_mask_spec is not None:
-            raise NotImplementedError(
-                "action mask is not yet compatible with q value modules"
-            )
+            action_mask_key = (group, "action_mask")
+        else:
+            action_mask_key = None
+
         value_module = QValueModule(
             action_value_key=(group, "action_value"),
+            action_mask_key=action_mask_key,
             out_keys=[
                 (group, "action"),
                 (group, "action_value"),
@@ -153,15 +155,20 @@ def _get_policy_for_loss(
     def _get_policy_for_collection(
         self, policy_for_loss: TensorDictModule, group: str, continuous: bool
     ) -> TensorDictModule:
+        if self.action_mask_spec is not None:
+            action_mask_key = (group, "action_mask")
+        else:
+            action_mask_key = None
+
-        return EGreedyWrapper(
-            policy_for_loss,
+        greedy = EGreedyModule(
             annealing_num_steps=self.experiment_config.exploration_annealing_num_frames,
             action_key=(group, "action"),
             spec=self.action_spec[(group, "action")],
+            action_mask_key=action_mask_key,
             # eps_init = 1.0,
             # eps_end = 0.1,
         )
+        return TensorDictSequential(*policy_for_loss, greedy)

     def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
         keys = list(batch.keys(True, True))
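
The action-mask handling above boils down to a masked argmax: instead of raising NotImplementedError, the Q-value module now receives the (group, "action_mask") entry and only considers legal actions when selecting the greedy action. Conceptually (plain PyTorch, illustrative only, not TorchRL's exact implementation):

import torch

action_value = torch.tensor([[1.0, 3.0, 2.0, 0.5]])
action_mask = torch.tensor([[True, False, True, True]])  # action 1 is illegal

# Invalid actions get an effectively -inf value before the argmax.
masked_value = action_value.masked_fill(
    ~action_mask, torch.finfo(action_value.dtype).min
)
action = masked_value.argmax(dim=-1)  # tensor([2]) rather than the illegal action 1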
19 changes: 13 additions & 6 deletions benchmarl/algorithms/qmix.py
@@ -13,7 +13,7 @@
 )
 from torchrl.data.replay_buffers.samplers import PrioritizedSampler
 from torchrl.data.replay_buffers.storages import LazyTensorStorage
-from torchrl.modules import EGreedyWrapper, QMixer, QValueModule
+from torchrl.modules import EGreedyModule, QMixer, QValueModule
 from torchrl.objectives import ClipPPOLoss, LossModule, QMixerLoss, ValueEstimators
 from torchrl.objectives.utils import SoftUpdate, TargetNetUpdater

@@ -142,11 +142,13 @@ def _get_policy_for_loss(
             device=self.device,
         )
         if self.action_mask_spec is not None:
-            raise NotImplementedError(
-                "action mask is not yet compatible with q value modules"
-            )
+            action_mask_key = (group, "action_mask")
+        else:
+            action_mask_key = None
+
         value_module = QValueModule(
             action_value_key=(group, "action_value"),
+            action_mask_key=action_mask_key,
             out_keys=[
                 (group, "action"),
                 (group, "action_value"),
@@ -161,15 +163,20 @@ def _get_policy_for_loss(
     def _get_policy_for_collection(
         self, policy_for_loss: TensorDictModule, group: str, continuous: bool
    ) -> TensorDictModule:
+        if self.action_mask_spec is not None:
+            action_mask_key = (group, "action_mask")
+        else:
+            action_mask_key = None
+
-        return EGreedyWrapper(
-            policy_for_loss,
+        greedy = EGreedyModule(
             annealing_num_steps=self.experiment_config.exploration_annealing_num_frames,
             action_key=(group, "action"),
             spec=self.action_spec[(group, "action")],
+            action_mask_key=action_mask_key,
             # eps_init = 1.0,
             # eps_end = 0.1,
         )
+        return TensorDictSequential(*policy_for_loss, greedy)

     def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
         keys = list(batch.keys(True, True))
19 changes: 13 additions & 6 deletions benchmarl/algorithms/vdn.py
@@ -13,7 +13,7 @@
 )
 from torchrl.data.replay_buffers.samplers import PrioritizedSampler
 from torchrl.data.replay_buffers.storages import LazyTensorStorage
-from torchrl.modules import EGreedyWrapper, QValueModule, VDNMixer
+from torchrl.modules import EGreedyModule, QValueModule, VDNMixer
 from torchrl.objectives import ClipPPOLoss, LossModule, QMixerLoss, ValueEstimators
 from torchrl.objectives.utils import SoftUpdate, TargetNetUpdater

@@ -136,11 +136,13 @@ def _get_policy_for_loss(
             device=self.device,
         )
         if self.action_mask_spec is not None:
-            raise NotImplementedError(
-                "action mask is not yet compatible with q value modules"
-            )
+            action_mask_key = (group, "action_mask")
+        else:
+            action_mask_key = None
+
         value_module = QValueModule(
             action_value_key=(group, "action_value"),
+            action_mask_key=action_mask_key,
             out_keys=[
                 (group, "action"),
                 (group, "action_value"),
@@ -155,15 +157,20 @@ def _get_policy_for_loss(
     def _get_policy_for_collection(
         self, policy_for_loss: TensorDictModule, group: str, continuous: bool
     ) -> TensorDictModule:
+        if self.action_mask_spec is not None:
+            action_mask_key = (group, "action_mask")
+        else:
+            action_mask_key = None
+
-        return EGreedyWrapper(
-            policy_for_loss,
+        greedy = EGreedyModule(
             annealing_num_steps=self.experiment_config.exploration_annealing_num_frames,
             action_key=(group, "action"),
             spec=self.action_spec[(group, "action")],
+            action_mask_key=action_mask_key,
             # eps_init = 1.0,
             # eps_end = 0.1,
         )
+        return TensorDictSequential(*policy_for_loss, greedy)

     def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
         keys = list(batch.keys(True, True))
18 changes: 16 additions & 2 deletions benchmarl/experiment.py
@@ -2,6 +2,7 @@
 from dataclasses import dataclass, MISSING
 from typing import Optional

+from tensordict.nn import TensorDictSequential
 from torchrl.collectors import SyncDataCollector

 from benchmarl.algorithms.common import AlgorithmConfig
@@ -175,6 +176,13 @@ def _setup_algorithm(self):

     def _setup_collector(self):
         self.policy = self.algorithm.get_policy_for_collection()
+
+        self.group_policies = {}
+        for group in self.group_map.keys():
+            group_policy = self.policy.select_subsequence(out_keys=[(group, "action")])
+            assert len(group_policy) == 1
+            self.group_policies.update({group: group_policy[0]})
+
         self.collector = SyncDataCollector(
             self.env,
             self.policy,
@@ -219,8 +227,14 @@ def run(self):
                     assert False
                 if self.target_updaters[group] is not None:
                     self.target_updaters[group].step()
-                if hasattr(self.policy, "step"):  # Step exploration annealing
-                    self.policy.step(current_frames)
+
+                if isinstance(self.group_policies[group], TensorDictSequential):
+                    explore_layer = self.group_policies[group][-1]
+                else:
+                    explore_layer = self.group_policies[group]
+                if hasattr(explore_layer, "step"):  # Step exploration annealing
+                    explore_layer.step(current_frames)

             self.collector.update_policy_weights_()

             self.n_iters_performed += 1
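
The per-group bookkeeping above relies on TensorDictSequential.select_subsequence, which keeps only the modules needed to produce the requested out_keys; since the collection policy stacks one self-contained sub-sequence per group (as the assert suggests), selecting on (group, "action") returns exactly one element, and its last layer is the group's exploration module, which is what gets stepped. A self-contained sketch with hypothetical group names and plain linear layers standing in for the real per-group policies:

import torch
from tensordict.nn import TensorDictModule, TensorDictSequential

# Two independent per-group policies stacked into one collection policy.
policy = TensorDictSequential(
    TensorDictSequential(
        TensorDictModule(torch.nn.Linear(4, 2), [("agents_a", "obs")], [("agents_a", "action")]),
    ),
    TensorDictSequential(
        TensorDictModule(torch.nn.Linear(4, 2), [("agents_b", "obs")], [("agents_b", "action")]),
    ),
)

# Keep only what is needed to produce agents_a's action: a single sub-sequence.
group_policy = policy.select_subsequence(out_keys=[("agents_a", "action")])
assert len(group_policy) == 1
explore_layer = group_policy[0][-1]  # last module of the group's sub-sequence
# In the experiment this would be the group's EGreedyModule; here it is a plain
# module without a step() method, so the hasattr guard simply skips annealing.
if hasattr(explore_layer, "step"):
    explore_layer.step(10)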
3 changes: 1 addition & 2 deletions benchmarl/hydra_run.py
@@ -2,8 +2,7 @@
 from hydra.core.hydra_config import HydraConfig
 from omegaconf import DictConfig

-from benchmarl import algorithm_config_registry
-
+from benchmarl.algorithms import algorithm_config_registry
 from benchmarl.environments import task_config_registry
 from benchmarl.experiment import Experiment, ExperimentConfig
 from benchmarl.models.common import ModelConfig
16 changes: 10 additions & 6 deletions test/test_algorithms.py
@@ -2,11 +2,8 @@

 from benchmarl.algorithms import algorithm_config_registry
 from benchmarl.environments import VmasTask
-from benchmarl.experiment import (
-    Experiment,
-    ExperimentConfig,
-    load_experiment_from_hydra_config,
-)
+from benchmarl.experiment import Experiment, ExperimentConfig
+from benchmarl.hydra_run import load_experiment_from_hydra_config
 from benchmarl.models.common import SequenceModelConfig
 from benchmarl.models.mlp import MlpConfig
 from hydra import compose, initialize
@@ -47,7 +44,14 @@ def test_all_algos_hydra(algo_config):
     )
     task_name = cfg.hydra.runtime.choices.task
     algo_name = cfg.hydra.runtime.choices.algorithm
+    model_config = SequenceModelConfig(
+        model_configs=[
+            MlpConfig(num_cells=[8]),
+            MlpConfig(num_cells=[4]),
+        ],
+        intermediate_sizes=[5],
+    )
     experiment = load_experiment_from_hydra_config(
-        cfg, algo_name=algo_name, task_name=task_name
+        cfg, algo_name=algo_name, task_name=task_name, model_config=model_config
     )
     experiment.run()