From 307456bc7185e7b0fd19f52e7c2759a2dba4b60e Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 13 Sep 2023 16:13:15 +0100 Subject: [PATCH] amend Signed-off-by: Matteo Bettini --- benchmarl/experiment/experiment.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index 2d215f95..8cd16c8a 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -288,13 +288,7 @@ def _collection_loop(self): # Loop over groups training_start = time.time() for group in self.group_map.keys(): - group_batch = batch.exclude( - *[ - group_name - for group_name in self.group_map.keys() - if group_name != group - ] - ) + group_batch = batch.exclude(*self._get_excluded_keys(group)) group_batch = self.algorithm.process_batch(group, group_batch) group_batch = group_batch.reshape(-1) self.replay_buffers[group].extend(group_batch) @@ -341,7 +335,7 @@ def _collection_loop(self): # Evaluation if ( - self.config.evaluation_episodes > 0 + self.config.evaluation and self.n_iters_performed % self.config.evaluation_interval == 0 ): self._evaluation_loop(iter=self.n_iters_performed) @@ -352,6 +346,14 @@ def _collection_loop(self): self.collector.shutdown() + def _get_excluded_keys(self, group: str): + excluded_keys = [] + for other_group in self.group_map.keys(): + if other_group != group: + excluded_keys += [other_group, ("next", other_group)] + excluded_keys += [(group, "info"), ("next", group, "info")] + return excluded_keys + def _optimizer_loop(self, group: str) -> TensorDictBase: subdata = self.replay_buffers[group].sample() loss_vals = self.losses[group](subdata)