Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Jul 31, 2024
1 parent 6edcdc1 commit 88c3bc8
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 36 deletions.
14 changes: 8 additions & 6 deletions benchmarl/algorithms/iddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,14 @@ def _get_policy_for_loss(
in_keys=[(group, "param")],
out_keys=[(group, "action")],
distribution_class=TanhDelta if self.use_tanh_mapping else Delta,
distribution_kwargs={
"min": self.action_spec[(group, "action")].space.low,
"max": self.action_spec[(group, "action")].space.high,
}
if self.use_tanh_mapping
else {},
distribution_kwargs=(
{
"low": self.action_spec[(group, "action")].space.low,
"high": self.action_spec[(group, "action")].space.high,
}
if self.use_tanh_mapping
else {}
),
return_log_prob=False,
safe=not self.use_tanh_mapping,
)
Expand Down
4 changes: 2 additions & 2 deletions benchmarl/algorithms/ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ def _get_policy_for_loss(
),
distribution_kwargs=(
{
"min": self.action_spec[(group, "action")].space.low,
"max": self.action_spec[(group, "action")].space.high,
"low": self.action_spec[(group, "action")].space.low,
"high": self.action_spec[(group, "action")].space.high,
}
if self.use_tanh_normal
else {}
Expand Down
20 changes: 11 additions & 9 deletions benchmarl/algorithms/isac.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,15 +199,17 @@ def _get_policy_for_loss(
spec=self.action_spec[group, "action"],
in_keys=[(group, "loc"), (group, "scale")],
out_keys=[(group, "action")],
distribution_class=IndependentNormal
if not self.use_tanh_normal
else TanhNormal,
distribution_kwargs={
"min": self.action_spec[(group, "action")].space.low,
"max": self.action_spec[(group, "action")].space.high,
}
if self.use_tanh_normal
else {},
distribution_class=(
IndependentNormal if not self.use_tanh_normal else TanhNormal
),
distribution_kwargs=(
{
"low": self.action_spec[(group, "action")].space.low,
"high": self.action_spec[(group, "action")].space.high,
}
if self.use_tanh_normal
else {}
),
return_log_prob=True,
log_prob_key=(group, "log_prob"),
)
Expand Down
14 changes: 8 additions & 6 deletions benchmarl/algorithms/maddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,14 @@ def _get_policy_for_loss(
in_keys=[(group, "param")],
out_keys=[(group, "action")],
distribution_class=TanhDelta if self.use_tanh_mapping else Delta,
distribution_kwargs={
"min": self.action_spec[(group, "action")].space.low,
"max": self.action_spec[(group, "action")].space.high,
}
if self.use_tanh_mapping
else {},
distribution_kwargs=(
{
"low": self.action_spec[(group, "action")].space.low,
"high": self.action_spec[(group, "action")].space.high,
}
if self.use_tanh_mapping
else {}
),
return_log_prob=False,
safe=not self.use_tanh_mapping,
)
Expand Down
4 changes: 2 additions & 2 deletions benchmarl/algorithms/mappo.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ def _get_policy_for_loss(
),
distribution_kwargs=(
{
"min": self.action_spec[(group, "action")].space.low,
"max": self.action_spec[(group, "action")].space.high,
"low": self.action_spec[(group, "action")].space.low,
"high": self.action_spec[(group, "action")].space.high,
}
if self.use_tanh_normal
else {}
Expand Down
20 changes: 11 additions & 9 deletions benchmarl/algorithms/masac.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,15 +199,17 @@ def _get_policy_for_loss(
spec=self.action_spec[group, "action"],
in_keys=[(group, "loc"), (group, "scale")],
out_keys=[(group, "action")],
distribution_class=IndependentNormal
if not self.use_tanh_normal
else TanhNormal,
distribution_kwargs={
"min": self.action_spec[(group, "action")].space.low,
"max": self.action_spec[(group, "action")].space.high,
}
if self.use_tanh_normal
else {},
distribution_class=(
IndependentNormal if not self.use_tanh_normal else TanhNormal
),
distribution_kwargs=(
{
"low": self.action_spec[(group, "action")].space.low,
"high": self.action_spec[(group, "action")].space.high,
}
if self.use_tanh_normal
else {}
),
return_log_prob=True,
log_prob_key=(group, "log_prob"),
)
Expand Down
2 changes: 1 addition & 1 deletion benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,7 @@ def _grad_clip(self, optimizer: torch.optim.Optimizer) -> float:
def _evaluation_loop(self):
evaluation_start = time.time()
with set_exploration_type(
ExplorationType.MODE
ExplorationType.DETERMINISTIC
if self.config.evaluation_deterministic_actions
else ExplorationType.RANDOM
):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def get_version():
url="https://github.com/facebookresearch/BenchMARL",
author="Matteo Bettini",
author_email="[email protected]",
install_requires=["torchrl>=0.4.0", "tqdm", "hydra-core"],
install_requires=["torchrl>=0.5.0", "tqdm", "hydra-core"],
extras_require={
"vmas": ["vmas>=1.3.4"],
"pettingzoo": ["pettingzoo[all]>=1.24.3"],
Expand Down

0 comments on commit 88c3bc8

Please sign in to comment.