Skip to content

Commit

Permalink
fix eval
Browse files (browse the repository at this point in the history)
  • Loading branch information
matteobettini committed Sep 6, 2024
1 parent 8f84b67 commit ac2bc9f
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions benchmarl/experiment/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ def log_evaluation(
json_metrics = {}
for group in self.group_map.keys():
# Cut the rollouts at the first done
for k, r in enumerate(rollouts):
rollouts_group = []
for r in rollouts:
next_done = self._get_done(group, r)
# Reduce it to batch size
next_done = next_done.sum(
Expand All @@ -178,19 +179,21 @@ def log_evaluation(
done_index = next_done.nonzero(as_tuple=True)[0]
if done_index.numel() > 0:
done_index = done_index[0]
rollouts[k] = r[: done_index + 1]
r = r[: done_index + 1]
rollouts_group.append(r)

returns = [
self._get_reward(group, td).sum(0).mean().item() for td in rollouts
self._get_reward(group, td).sum(0).mean().item()
for td in rollouts_group
]
json_metrics[group + "_return"] = torch.tensor(
returns, device=rollouts[0].device
returns, device=rollouts_group[0].device
)
to_log.update(
{
f"eval/{group}/reward/episode_reward_min": min(returns),
f"eval/{group}/reward/episode_reward_mean": sum(returns)
/ len(rollouts),
/ len(rollouts_group),
f"eval/{group}/reward/episode_reward_max": max(returns),
}
)
Expand Down

0 comments on commit ac2bc9f

Please sign in to comment.