From b6c0a7994217397898cb216b401034f9c52aaf39 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 6 Oct 2023 15:11:41 +0100 Subject: [PATCH] update logger keys Signed-off-by: Matteo Bettini --- benchmarl/experiment/logger.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/benchmarl/experiment/logger.py b/benchmarl/experiment/logger.py index 2942a527..d32dd424 100644 --- a/benchmarl/experiment/logger.py +++ b/benchmarl/experiment/logger.py @@ -88,9 +88,9 @@ def log_collection( reward = self._get_reward(group, batch) to_log.update( { - f"collection/{group}/reward_min": reward.min().item(), - f"collection/{group}/reward_mean": reward.mean().item(), - f"collection/{group}/reward_max": reward.max().item(), + f"collection/{group}/reward/reward_min": reward.min().item(), + f"collection/{group}/reward/reward_mean": reward.mean().item(), + f"collection/{group}/reward/reward_max": reward.max().item(), } ) json_metrics[group + "_return"] = episode_reward.mean(-2)[done.any(-2)] @@ -98,9 +98,9 @@ def log_collection( if episode_reward.numel() > 0: to_log.update( { - f"collection/{group}/episode_reward_min": episode_reward.min().item(), - f"collection/{group}/episode_reward_mean": episode_reward.mean().item(), - f"collection/{group}/episode_reward_max": episode_reward.max().item(), + f"collection/{group}/reward/episode_reward_min": episode_reward.min().item(), + f"collection/{group}/reward/episode_reward_mean": episode_reward.mean().item(), + f"collection/{group}/reward/episode_reward_max": episode_reward.max().item(), } ) if "info" in batch.get(("next", group)).keys(): @@ -126,9 +126,9 @@ def log_collection( if mean_group_return.numel() > 0: to_log.update( { - "collection/episode_reward_min": mean_group_return.min().item(), - "collection/episode_reward_mean": mean_group_return.mean().item(), - "collection/episode_reward_max": mean_group_return.max().item(), + "collection/reward/episode_reward_min": mean_group_return.min().item(), + "collection/reward/episode_reward_mean": mean_group_return.mean().item(), + "collection/reward/episode_reward_max": mean_group_return.max().item(), } ) self.log(to_log, step=step) @@ -179,9 +179,10 @@ def log_evaluation( ) to_log.update( { - f"eval/{group}/episode_reward_min": min(returns), - f"eval/{group}/episode_reward_mean": sum(returns) / len(rollouts), - f"eval/{group}/episode_reward_max": max(returns), + f"eval/{group}/reward/episode_reward_min": min(returns), + f"eval/{group}/reward/episode_reward_mean": sum(returns) + / len(rollouts), + f"eval/{group}/reward/episode_reward_max": max(returns), } )