Commit 0f299ca
pass entire experiment to model
matteobettini committed Oct 20, 2023
1 parent 0b8436c commit 0f299ca
Showing 11 changed files with 32 additions and 23 deletions.
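The change is the same at every call site below: wherever an algorithm builds a model from its ModelConfig, the builder call now also receives experiment=self.experiment, so the constructed model carries a reference to the owning experiment. A minimal, self-contained Python sketch of that pattern follows; the stand-in classes and parameter names are assumptions for illustration, and only the experiment= keyword argument is taken from the diff.

# Sketch of the pattern (stand-in classes, not BenchMARL's real API).
from dataclasses import dataclass
from typing import Optional


@dataclass
class Experiment:
    # Stand-in for the experiment object that holds run-wide state and config.
    seed: int = 0
    on_policy: bool = True


class Model:
    def __init__(self, n_agents: int, centralised: bool, share_params: bool,
                 device: str, experiment: Optional[Experiment]):
        # The model keeps a back-reference to the experiment, mirroring the
        # `experiment=self.experiment` argument added at every call site below.
        self.n_agents = n_agents
        self.centralised = centralised
        self.share_params = share_params
        self.device = device
        self.experiment = experiment


class ModelConfig:
    # Stand-in for a model config whose builder now accepts `experiment`.
    def get_model(self, *, n_agents: int, centralised: bool, share_params: bool,
                  device: str, experiment: Optional[Experiment] = None) -> Model:
        return Model(n_agents, centralised, share_params, device, experiment)


exp = Experiment(seed=42)
model = ModelConfig().get_model(
    n_agents=3,
    centralised=False,
    share_params=True,
    device="cpu",
    experiment=exp,  # new keyword: the model can now reach experiment-level settings
)
assert model.experiment is exp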
2 changes: 2 additions & 0 deletions benchmarl/algorithms/iddpg.py
@@ -101,6 +101,7 @@ def _get_policy_for_loss(
centralised=False,
share_params=self.experiment_config.share_policy_params,
device=self.device,
experiment=self.experiment,
)

policy = ProbabilisticActor(
@@ -217,6 +218,7 @@ def get_value_module(self, group: str) -> TensorDictModule:
agent_group=group,
share_params=self.share_param_critic,
device=self.device,
experiment=self.experiment,
)
)

5 changes: 2 additions & 3 deletions benchmarl/algorithms/ippo.py
@@ -48,7 +48,6 @@ def __init__(
def _get_loss(
self, group: str, policy_for_loss: TensorDictModule, continuous: bool
) -> Tuple[LossModule, bool]:

# Loss
loss_module = ClipPPOLoss(
actor=policy_for_loss,
@@ -83,7 +82,6 @@ def _get_parameters(self, group: str, loss: ClipPPOLoss) -> Dict[str, Iterable]:
def _get_policy_for_loss(
self, group: str, model_config: ModelConfig, continuous: bool
) -> TensorDictModule:

n_agents = len(self.group_map[group])
if continuous:
logits_shape = list(self.action_spec[group, "action"].shape)
@@ -124,6 +122,7 @@ def _get_policy_for_loss(
centralised=False,
share_params=self.experiment_config.share_policy_params,
device=self.device,
experiment=self.experiment,
)

if continuous:
@@ -261,14 +260,14 @@ def get_critic(self, group: str) -> TensorDictModule:
agent_group=group,
share_params=self.share_param_critic,
device=self.device,
experiment=self.experiment,
)

return value_module


@dataclass
class IppoConfig(AlgorithmConfig):

share_param_critic: bool = MISSING
clip_epsilon: float = MISSING
entropy_coef: float = MISSING

3 changes: 1 addition & 2 deletions benchmarl/algorithms/iql.py
@@ -62,7 +62,6 @@ def _get_parameters(self, group: str, loss: LossModule) -> Dict[str, Iterable]:
def _get_policy_for_loss(
self, group: str, model_config: ModelConfig, continuous: bool
) -> TensorDictModule:

n_agents = len(self.group_map[group])
logits_shape = [
*self.action_spec[group, "action"].shape,
@@ -99,6 +98,7 @@ def _get_policy_for_loss(
centralised=False,
share_params=self.experiment_config.share_policy_params,
device=self.device,
experiment=self.experiment,
)
if self.action_mask_spec is not None:
action_mask_key = (group, "action_mask")
@@ -175,7 +175,6 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:

@dataclass
class IqlConfig(AlgorithmConfig):

delay_value: bool = MISSING
loss_function: str = MISSING

5 changes: 3 additions & 2 deletions benchmarl/algorithms/isac.py
@@ -126,7 +126,6 @@ def _get_parameters(self, group: str, loss: ClipPPOLoss) -> Dict[str, Iterable]:
def _get_policy_for_loss(
self, group: str, model_config: ModelConfig, continuous: bool
) -> TensorDictModule:

n_agents = len(self.group_map[group])
if continuous:
logits_shape = list(self.action_spec[group, "action"].shape)
@@ -167,6 +166,7 @@ def _get_policy_for_loss(
centralised=False,
share_params=self.experiment_config.share_policy_params,
device=self.device,
experiment=self.experiment,
)

if continuous:
@@ -291,6 +291,7 @@ def get_discrete_value_module(self, group: str) -> TensorDictModule:
agent_group=group,
share_params=self.share_param_critic,
device=self.device,
experiment=self.experiment,
)

return value_module
@@ -346,6 +347,7 @@ def get_continuous_value_module(self, group: str) -> TensorDictModule:
agent_group=group,
share_params=self.share_param_critic,
device=self.device,
experiment=self.experiment,
)
)

@@ -354,7 +356,6 @@ def get_continuous_value_module(self, group: str) -> TensorDictModule:

@dataclass
class IsacConfig(AlgorithmConfig):

share_param_critic: bool = MISSING

num_qvalue_nets: int = MISSING

6 changes: 3 additions & 3 deletions benchmarl/algorithms/maddpg.py
@@ -62,7 +62,6 @@ def _get_loss(
)

def _get_parameters(self, group: str, loss: LossModule) -> Dict[str, Iterable]:

return {
"loss_actor": list(loss.actor_network_params.flatten_keys().values()),
"loss_value": list(loss.value_network_params.flatten_keys().values()),
@@ -103,6 +102,7 @@ def _get_policy_for_loss(
centralised=False,
share_params=self.experiment_config.share_policy_params,
device=self.device,
experiment=self.experiment,
)

policy = ProbabilisticActor(
@@ -222,11 +222,11 @@ def get_value_module(self, group: str) -> TensorDictModule:
agent_group=group,
share_params=self.share_param_critic,
device=self.device,
experiment=self.experiment,
)
)

else:

modules.append(
TensorDictModule(
lambda obs, action: torch.cat([obs, action], dim=-1),
@@ -263,6 +263,7 @@ def get_value_module(self, group: str) -> TensorDictModule:
agent_group=group,
share_params=self.share_param_critic,
device=self.device,
experiment=self.experiment,
)
)

@@ -282,7 +283,6 @@ def get_value_module(self, group: str) -> TensorDictModule:

@dataclass
class MaddpgConfig(AlgorithmConfig):

share_param_critic: bool = MISSING

loss_function: str = MISSING

3 changes: 3 additions & 0 deletions benchmarl/algorithms/mappo.py
@@ -123,6 +123,7 @@ def _get_policy_for_loss(
centralised=False,
share_params=self.experiment_config.share_policy_params,
device=self.device,
experiment=self.experiment,
)

if continuous:
Expand Down Expand Up @@ -258,6 +259,7 @@ def get_critic(self, group: str) -> TensorDictModule:
agent_group=group,
share_params=self.share_param_critic,
device=self.device,
experiment=self.experiment,
)

else:
@@ -282,6 +284,7 @@ def get_critic(self, group: str) -> TensorDictModule:
agent_group=group,
share_params=self.share_param_critic,
device=self.device,
experiment=self.experiment,
)
if self.share_param_critic:
expand_module = TensorDictModule(

8 changes: 5 additions & 3 deletions benchmarl/algorithms/masac.py
@@ -121,7 +121,6 @@ def _get_parameters(self, group: str, loss: LossModule) -> Dict[str, Iterable]:
def _get_policy_for_loss(
self, group: str, model_config: ModelConfig, continuous: bool
) -> TensorDictModule:

n_agents = len(self.group_map[group])
if continuous:
logits_shape = list(self.action_spec[group, "action"].shape)
@@ -162,6 +161,7 @@ def _get_policy_for_loss(
centralised=False,
share_params=self.experiment_config.share_policy_params,
device=self.device,
experiment=self.experiment,
)

if continuous:
@@ -279,6 +279,7 @@ def get_discrete_value_module(self, group: str) -> TensorDictModule:
agent_group=group,
share_params=self.share_param_critic,
device=self.device,
experiment=self.experiment,
)

else:
@@ -303,6 +304,7 @@ def get_discrete_value_module(self, group: str) -> TensorDictModule:
agent_group=group,
share_params=self.share_param_critic,
device=self.device,
experiment=self.experiment,
)
if self.share_param_critic:
expand_module = TensorDictModule(
@@ -369,11 +371,11 @@ def get_continuous_value_module(self, group: str) -> TensorDictModule:
agent_group=group,
share_params=self.share_param_critic,
device=self.device,
experiment=self.experiment,
)
)

else:

modules.append(
TensorDictModule(
lambda obs, action: torch.cat([obs, action], dim=-1),
@@ -410,6 +412,7 @@ def get_continuous_value_module(self, group: str) -> TensorDictModule:
agent_group=group,
share_params=self.share_param_critic,
device=self.device,
experiment=self.experiment,
)
)

@@ -429,7 +432,6 @@ def get_continuous_value_module(self, group: str) -> TensorDictModule:

@dataclass
class MasacConfig(AlgorithmConfig):

share_param_critic: bool = MISSING

num_qvalue_nets: int = MISSING

4 changes: 1 addition & 3 deletions benchmarl/algorithms/qmix.py
@@ -67,7 +67,6 @@ def _get_parameters(self, group: str, loss: LossModule) -> Dict[str, Iterable]:
def _get_policy_for_loss(
self, group: str, model_config: ModelConfig, continuous: bool
) -> TensorDictModule:

n_agents = len(self.group_map[group])
logits_shape = [
*self.action_spec[group, "action"].shape,
@@ -104,6 +103,7 @@ def _get_policy_for_loss(
centralised=False,
share_params=self.experiment_config.share_policy_params,
device=self.device,
experiment=self.experiment,
)
if self.action_mask_spec is not None:
action_mask_key = (group, "action_mask")
@@ -175,7 +175,6 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
#####################

def get_mixer(self, group: str) -> TensorDictModule:

n_agents = len(self.group_map[group])

if self.state_spec is not None:
@@ -201,7 +200,6 @@ def get_mixer(self, group: str) -> TensorDictModule:

@dataclass
class QmixConfig(AlgorithmConfig):

mixing_embed_dim: int = MISSING
delay_value: bool = MISSING
loss_function: str = MISSING

5 changes: 1 addition & 4 deletions benchmarl/algorithms/vdn.py
@@ -59,15 +59,13 @@ def _get_loss(
return loss_module, True

def _get_parameters(self, group: str, loss: LossModule) -> Dict[str, Iterable]:

return {
"loss": loss.parameters(),
}

def _get_policy_for_loss(
self, group: str, model_config: ModelConfig, continuous: bool
) -> TensorDictModule:

n_agents = len(self.group_map[group])
logits_shape = [
*self.action_spec[group, "action"].shape,
@@ -104,6 +102,7 @@ def _get_policy_for_loss(
centralised=False,
share_params=self.experiment_config.share_policy_params,
device=self.device,
experiment=self.experiment,
)
if self.action_mask_spec is not None:
action_mask_key = (group, "action_mask")
@@ -175,7 +174,6 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
#####################

def get_mixer(self, group: str) -> TensorDictModule:

n_agents = len(self.group_map[group])
mixer = TensorDictModule(
module=VDNMixer(
@@ -191,7 +189,6 @@ def get_mixer(self, group: str) -> TensorDictModule:

@dataclass
class VdnConfig(AlgorithmConfig):

delay_value: bool = MISSING
loss_function: str = MISSING

Diffs for the remaining 2 changed files are not shown.
