Skip to content

Commit

Permalink
[Refactor] RewardSum transform
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Dec 5, 2023
1 parent db6696d commit 615b3ba
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 9 deletions.
6 changes: 5 additions & 1 deletion benchmarl/environments/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,11 @@ def get_reward_sum_transform(self, env: EnvBase) -> Transform:
Args:
env (EnvBase): An environment created via self.get_env_fun
"""
return RewardSum(reset_keys=env.reset_keys)
if "_reset" in env.reset_keys:
reset_keys = ["_reset"] * len(self.group_map(env).keys())
else:
reset_keys = env.reset_keys
return RewardSum(reset_keys=reset_keys)

@staticmethod
def render_callback(experiment, env: EnvBase, data: TensorDictBase):
Expand Down
9 changes: 1 addition & 8 deletions benchmarl/environments/pettingzoo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Callable, Dict, List, Optional

from torchrl.data import CompositeSpec
from torchrl.envs import EnvBase, PettingZooEnv, RewardSum, Transform
from torchrl.envs import EnvBase, PettingZooEnv

from benchmarl.environments.common import Task

Expand Down Expand Up @@ -147,13 +147,6 @@ def info_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
def action_spec(self, env: EnvBase) -> CompositeSpec:
return env.input_spec["full_action_spec"]

def get_reward_sum_transform(self, env: EnvBase) -> Transform:
if "_reset" in env.reset_keys:
reset_keys = ["_reset"] * len(self.group_map(env).keys())
else:
reset_keys = env.reset_keys
return RewardSum(reset_keys=reset_keys)

@staticmethod
def env_name() -> str:
return "pettingzoo"

0 comments on commit 615b3ba

Please sign in to comment.