Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
Signed-off-by: Matteo Bettini <[email protected]>
  • Loading branch information
matteobettini committed Sep 13, 2023
1 parent 05f775e commit 307456b
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,13 +288,7 @@ def _collection_loop(self):
# Loop over groups
training_start = time.time()
for group in self.group_map.keys():
group_batch = batch.exclude(
*[
group_name
for group_name in self.group_map.keys()
if group_name != group
]
)
group_batch = batch.exclude(*self._get_excluded_keys(group))
group_batch = self.algorithm.process_batch(group, group_batch)
group_batch = group_batch.reshape(-1)
self.replay_buffers[group].extend(group_batch)
Expand Down Expand Up @@ -341,7 +335,7 @@ def _collection_loop(self):

# Evaluation
if (
self.config.evaluation_episodes > 0
self.config.evaluation
and self.n_iters_performed % self.config.evaluation_interval == 0
):
self._evaluation_loop(iter=self.n_iters_performed)
Expand All @@ -352,6 +346,14 @@ def _collection_loop(self):

self.collector.shutdown()

def _get_excluded_keys(self, group: str):
excluded_keys = []
for other_group in self.group_map.keys():
if other_group != group:
excluded_keys += [other_group, ("next", other_group)]
excluded_keys += [(group, "info"), ("next", group, "info")]
return excluded_keys

def _optimizer_loop(self, group: str) -> TensorDictBase:
subdata = self.replay_buffers[group].sample()
loss_vals = self.losses[group](subdata)
Expand Down

0 comments on commit 307456b

Please sign in to comment.