Skip to content

Commit

Permalink
update logger keys
Browse files Browse the repository at this point in the history
Signed-off-by: Matteo Bettini <[email protected]>
  • Loading branch information
matteobettini committed Oct 6, 2023
1 parent 3f43db6 commit b6c0a79
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions benchmarl/experiment/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,19 +88,19 @@ def log_collection(
reward = self._get_reward(group, batch)
to_log.update(
{
f"collection/{group}/reward_min": reward.min().item(),
f"collection/{group}/reward_mean": reward.mean().item(),
f"collection/{group}/reward_max": reward.max().item(),
f"collection/{group}/reward/reward_min": reward.min().item(),
f"collection/{group}/reward/reward_mean": reward.mean().item(),
f"collection/{group}/reward/reward_max": reward.max().item(),
}
)
json_metrics[group + "_return"] = episode_reward.mean(-2)[done.any(-2)]
episode_reward = episode_reward[done]
if episode_reward.numel() > 0:
to_log.update(
{
f"collection/{group}/episode_reward_min": episode_reward.min().item(),
f"collection/{group}/episode_reward_mean": episode_reward.mean().item(),
f"collection/{group}/episode_reward_max": episode_reward.max().item(),
f"collection/{group}/reward/episode_reward_min": episode_reward.min().item(),
f"collection/{group}/reward/episode_reward_mean": episode_reward.mean().item(),
f"collection/{group}/reward/episode_reward_max": episode_reward.max().item(),
}
)
if "info" in batch.get(("next", group)).keys():
Expand All @@ -126,9 +126,9 @@ def log_collection(
if mean_group_return.numel() > 0:
to_log.update(
{
"collection/episode_reward_min": mean_group_return.min().item(),
"collection/episode_reward_mean": mean_group_return.mean().item(),
"collection/episode_reward_max": mean_group_return.max().item(),
"collection/reward/episode_reward_min": mean_group_return.min().item(),
"collection/reward/episode_reward_mean": mean_group_return.mean().item(),
"collection/reward/episode_reward_max": mean_group_return.max().item(),
}
)
self.log(to_log, step=step)
Expand Down Expand Up @@ -179,9 +179,10 @@ def log_evaluation(
)
to_log.update(
{
f"eval/{group}/episode_reward_min": min(returns),
f"eval/{group}/episode_reward_mean": sum(returns) / len(rollouts),
f"eval/{group}/episode_reward_max": max(returns),
f"eval/{group}/reward/episode_reward_min": min(returns),
f"eval/{group}/reward/episode_reward_mean": sum(returns)
/ len(rollouts),
f"eval/{group}/reward/episode_reward_max": max(returns),
}
)

Expand Down

0 comments on commit b6c0a79

Please sign in to comment.