From ac2bc9fee6f48a0ae802f6604e6322b7a0f3cc43 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 6 Sep 2024 14:14:11 +0200 Subject: [PATCH] fix eval --- benchmarl/experiment/logger.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/benchmarl/experiment/logger.py b/benchmarl/experiment/logger.py index 740389fd..316ccad0 100644 --- a/benchmarl/experiment/logger.py +++ b/benchmarl/experiment/logger.py @@ -167,7 +167,8 @@ def log_evaluation( json_metrics = {} for group in self.group_map.keys(): # Cut the rollouts at the first done - for k, r in enumerate(rollouts): + rollouts_group = [] + for r in rollouts: next_done = self._get_done(group, r) # Reduce it to batch size next_done = next_done.sum( @@ -178,19 +179,21 @@ def log_evaluation( done_index = next_done.nonzero(as_tuple=True)[0] if done_index.numel() > 0: done_index = done_index[0] - rollouts[k] = r[: done_index + 1] + r = r[: done_index + 1] + rollouts_group.append(r) returns = [ - self._get_reward(group, td).sum(0).mean().item() for td in rollouts + self._get_reward(group, td).sum(0).mean().item() + for td in rollouts_group ] json_metrics[group + "_return"] = torch.tensor( - returns, device=rollouts[0].device + returns, device=rollouts_group[0].device ) to_log.update( { f"eval/{group}/reward/episode_reward_min": min(returns), f"eval/{group}/reward/episode_reward_mean": sum(returns) - / len(rollouts), + / len(rollouts_group), f"eval/{group}/reward/episode_reward_max": max(returns), } )