[Feature] Init random batches
Signed-off-by: Matteo Bettini <[email protected]>
matteobettini committed Oct 5, 2023
1 parent 3b03189 commit 87b2f6d
Showing 6 changed files with 34 additions and 6 deletions.
6 changes: 4 additions & 2 deletions benchmarl/conf/experiment/base_experiment.yaml
@@ -13,7 +13,7 @@ share_policy_params: True
 prefer_continuous_actions: True
 
 # Discount factor
-gamma: 0.99
+gamma: 0.9
 # Learning rate
 lr: 0.00005
 # Clips grad norm if true and clips grad value if false
@@ -35,6 +35,8 @@ exploration_eps_end: 0.01
 
 # Number of frames collected in each experiment iteration
 collected_frames_per_batch: 6000
+# Number of initial collection batches containing random interactions
+init_random_batches: 0
 # Number of environments used for collection
 # If the environment is vectorized, this will be the number of batched environments.
 # Otherwise this batching will be simulated and each env will be run sequentially.
@@ -73,5 +75,5 @@ create_json: True
 save_folder: null
 # Absolute path to a checkpoint file where the experiment was saved. If null the experiment is started fresh.
 restore_file: null
-# Interval for experiment saving in terms of experiment iterations
+# Interval for experiment saving in terms of experiment iterations. Set it to 0 to disable checkpointing
 checkpoint_interval: 50
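
The new init_random_batches key defaults to 0, so existing configurations keep their behavior. As a minimal sketch (assuming the ExperimentConfig.get_from_yaml() entry point shown in the experiment.py diff below), the value can be overridden from Python:

from benchmarl.experiment import ExperimentConfig

# Load the defaults from base_experiment.yaml, then request two collection
# batches of purely random interactions before the policy starts acting.
config = ExperimentConfig.get_from_yaml()
config.init_random_batches = 2
# With collected_frames_per_batch: 6000, this gives 2 * 6000 = 12000 random frames.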
3 changes: 2 additions & 1 deletion benchmarl/environments/common.py
@@ -11,7 +11,7 @@
 from torchrl.data import CompositeSpec
 from torchrl.envs import EnvBase
 
-from benchmarl.utils import read_yaml_config
+from benchmarl.utils import DEVICE_TYPING, read_yaml_config
 
 
 def _load_config(name: str, config: Dict[str, Any]):
@@ -57,6 +57,7 @@ def get_env_fun(
         num_envs: int,
         continuous_actions: bool,
         seed: Optional[int],
+        device: DEVICE_TYPING,
     ) -> Callable[[], EnvBase]:
         raise NotImplementedError
 
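
DEVICE_TYPING itself is not part of this diff; presumably it is a plain typing alias in benchmarl.utils along these lines (an assumption, mirroring torchrl's convention):

from typing import Union

import torch

# Assumed definition, shown only for context; benchmarl/utils.py is not in this commit.
DEVICE_TYPING = Union[torch.device, str, int]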
4 changes: 4 additions & 0 deletions benchmarl/environments/pettingzoo/common.py
@@ -5,6 +5,8 @@
 
 from benchmarl.environments.common import Task
 
+from benchmarl.utils import DEVICE_TYPING
+
 
 class PettingZooTask(Task):
     MULTIWALKER = None
@@ -15,12 +17,14 @@ def get_env_fun(
         num_envs: int,
         continuous_actions: bool,
         seed: Optional[int],
+        device: DEVICE_TYPING,
     ) -> Callable[[], EnvBase]:
         if self.supports_continuous_actions() and self.supports_discrete_actions():
             self.config.update({"continuous_actions": continuous_actions})
 
         return lambda: PettingZooEnv(
             categorical_actions=True,
+            device=device,
             seed=seed,
             parallel=True,
             return_state=self.has_state(),
6 changes: 5 additions & 1 deletion benchmarl/environments/smacv2/common.py
@@ -7,6 +7,7 @@
 from torchrl.envs.libs.smacv2 import SMACv2Env
 
 from benchmarl.environments.common import Task
+from benchmarl.utils import DEVICE_TYPING
 
 
 class Smacv2Task(Task):
@@ -17,8 +18,11 @@ def get_env_fun(
         num_envs: int,
         continuous_actions: bool,
         seed: Optional[int],
+        device: DEVICE_TYPING,
     ) -> Callable[[], EnvBase]:
-        return lambda: SMACv2Env(categorical_actions=True, seed=seed, **self.config)
+        return lambda: SMACv2Env(
+            categorical_actions=True, seed=seed, device=device, **self.config
+        )
 
     def supports_continuous_actions(self) -> bool:
         return False
3 changes: 3 additions & 0 deletions benchmarl/environments/vmas/common.py
@@ -5,6 +5,7 @@
 from torchrl.envs.libs.vmas import VmasEnv
 
 from benchmarl.environments.common import Task
+from benchmarl.utils import DEVICE_TYPING
 
 
 class VmasTask(Task):
@@ -17,12 +18,14 @@ def get_env_fun(
         num_envs: int,
         continuous_actions: bool,
         seed: Optional[int],
+        device: DEVICE_TYPING,
     ) -> Callable[[], EnvBase]:
         return lambda: VmasEnv(
             scenario=self.name.lower(),
             num_envs=num_envs,
             continuous_actions=continuous_actions,
             seed=seed,
+            device=device,
             categorical_actions=True,
             **self.config,
         )
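
All three wrappers (PettingZooTask, Smacv2Task, VmasTask) follow the same pattern: the device requested by the experiment is forwarded to the underlying torchrl env constructor. A usage sketch, assuming VMAS is installed and using VmasTask.BALANCE purely as an illustrative member:

from benchmarl.environments import VmasTask

# Load the task's yaml config, then build the env factory on an explicit device,
# exercising the new device parameter threaded through get_env_fun.
task = VmasTask.BALANCE.get_from_yaml()
env_fn = task.get_env_fun(
    num_envs=4,
    continuous_actions=True,
    seed=0,
    device="cpu",
)
env = env_fn()  # a torchrl EnvBase created on the requested device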
18 changes: 16 additions & 2 deletions benchmarl/experiment/experiment.py
@@ -52,6 +52,7 @@ class ExperimentConfig:
     exploration_eps_end: float = MISSING
 
     collected_frames_per_batch: int = MISSING
+    init_random_batches: int = MISSING
     n_envs_per_worker: int = MISSING
     n_iters: int = MISSING
     n_optimizer_steps: int = MISSING
@@ -106,6 +107,10 @@ def total_frames(self) -> int:
     def exploration_anneal_frames(self) -> int:
         return self.total_frames // 3
 
+    @property
+    def init_random_frames(self) -> int:
+        return self.init_random_batches * self.collected_frames_per_batch
+
     @staticmethod
     def get_from_yaml(path: Optional[str] = None):
         if path is None:
@@ -191,13 +196,15 @@ def _setup_task(self):
                 num_envs=self.config.evaluation_episodes,
                 continuous_actions=self.continuous_actions,
                 seed=self.seed,
+                device=self.config.sampling_device,
             )
         )()
         env_func = self.model_config.process_env_fun(
             self.task.get_env_fun(
                 num_envs=self.config.n_envs_per_worker,
                 continuous_actions=self.continuous_actions,
                 seed=self.seed,
+                device=self.config.sampling_device,
             )
         )
 
@@ -219,7 +226,7 @@ def _setup_task(self):
         else:
             self.env_func = lambda: TransformedEnv(env_func(), transform.clone())
 
-        self.test_env = test_env
+        self.test_env = test_env.to(self.config.sampling_device)
 
     def _setup_algorithm(self):
         self.algorithm = self.algorithm_config.get_algorithm(
Expand Down Expand Up @@ -248,7 +255,7 @@ def _setup_algorithm(self):
}
self.optimizers = {
group: {
loss_name: torch.optim.Adam(params, lr=self.config.lr, eps=1e-4)
loss_name: torch.optim.Adam(params, lr=self.config.lr, eps=1e-6)
for loss_name, params in self.algorithm.get_parameters(group).items()
}
for group in self.group_map.keys()
@@ -270,6 +277,7 @@ def _setup_collector(self):
             storing_device=self.config.train_device,
             frames_per_batch=self.config.collected_frames_per_batch,
             total_frames=self.config.total_frames,
+            init_random_frames=self.config.init_random_frames,
         )
 
@@ -278,6 +286,12 @@ def _setup_name(self):
         self.environment_name = self.task.env_name().lower()
         self.task_name = self.task.name.lower()
 
+        if self.config.restore_file is not None and self.config.save_folder is not None:
+            raise ValueError(
+                "Experiment restore file and save folder have both been specified. "
+                "Do not set a save_folder when you are reloading an experiment, as "
+                "it will by default be reloaded into the old folder."
+            )
         if self.config.restore_file is None:
             if self.config.save_folder is not None:
                 folder_name = Path(self.config.save_folder)
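
The init_random_frames property converts the configured number of batches into a frame count, which is passed straight to the torchrl collector; the collector then samples random actions from the action spec until that many frames have been gathered, after which the trained policy takes over. A stripped-down stand-in (not the real ExperimentConfig) just to illustrate the arithmetic:

from dataclasses import dataclass


@dataclass
class Cfg:  # illustrative stand-in for ExperimentConfig
    collected_frames_per_batch: int = 6000
    init_random_batches: int = 2

    @property
    def init_random_frames(self) -> int:
        return self.init_random_batches * self.collected_frames_per_batch


assert Cfg().init_random_frames == 12000  # two fully random collection batches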
