diff --git a/benchmarl/algorithms/common.py b/benchmarl/algorithms/common.py
index f1a4196e..e742d808 100644
--- a/benchmarl/algorithms/common.py
+++ b/benchmarl/algorithms/common.py
@@ -41,6 +41,7 @@ def __init__(self, experiment):
         self.experiment = experiment
 
         self.device: DEVICE_TYPING = experiment.config.train_device
+        self.buffer_device: DEVICE_TYPING = experiment.config.buffer_device
         self.experiment_config = experiment.config
         self.model_config = experiment.model_config
         self.critic_model_config = experiment.critic_model_config
@@ -141,11 +142,12 @@ def get_replay_buffer(
         """
         memory_size = self.experiment_config.replay_buffer_memory_size(self.on_policy)
         sampling_size = self.experiment_config.train_minibatch_size(self.on_policy)
-        storing_device = self.device
         sampler = SamplerWithoutReplacement() if self.on_policy else RandomSampler()
-
         return TensorDictReplayBuffer(
-            storage=LazyTensorStorage(memory_size, device=storing_device),
+            storage=LazyTensorStorage(
+                memory_size,
+                device=self.device if self.on_policy else self.buffer_device,
+            ),
             sampler=sampler,
             batch_size=sampling_size,
             priority_key=(group, "td_error"),
diff --git a/benchmarl/conf/experiment/base_experiment.yaml b/benchmarl/conf/experiment/base_experiment.yaml
index e84f39bb..3b87f759 100644
--- a/benchmarl/conf/experiment/base_experiment.yaml
+++ b/benchmarl/conf/experiment/base_experiment.yaml
@@ -6,6 +6,8 @@ defaults:
 sampling_device: "cpu"
 # The device for training (e.g. cuda)
 train_device: "cpu"
+# The device for the replay buffer of off-policy algorithms (e.g. cuda)
+buffer_device: "cpu"
 
 # Whether to share the parameters of the policy within agent groups
 share_policy_params: True
diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py
index e74a97a0..5c400873 100644
--- a/benchmarl/experiment/experiment.py
+++ b/benchmarl/experiment/experiment.py
@@ -50,6 +50,7 @@ class ExperimentConfig:
 
     sampling_device: str = MISSING
     train_device: str = MISSING
+    buffer_device: str = MISSING
 
     share_policy_params: bool = MISSING
     prefer_continuous_actions: bool = MISSING
@@ -462,9 +463,9 @@ def _setup_collector(self):
             storing_device=self.config.train_device,
             frames_per_batch=self.config.collected_frames_per_batch(self.on_policy),
             total_frames=self.config.get_max_n_frames(self.on_policy),
-            init_random_frames=self.config.off_policy_init_random_frames
-            if not self.on_policy
-            else 0,
+            init_random_frames=(
+                self.config.off_policy_init_random_frames if not self.on_policy else 0
+            ),
         )
 
     def _setup_name(self):
@@ -647,7 +648,7 @@ def _get_excluded_keys(self, group: str):
         return excluded_keys
 
     def _optimizer_loop(self, group: str) -> TensorDictBase:
-        subdata = self.replay_buffers[group].sample()
+        subdata = self.replay_buffers[group].sample().to(self.config.train_device)
         loss_vals = self.losses[group](subdata)
         training_td = loss_vals.detach()
         loss_vals = self.algorithm.process_loss_vals(group, loss_vals)
diff --git a/benchmarl/hydra_config.py b/benchmarl/hydra_config.py
index 48a520ed..4366e5a5 100644
--- a/benchmarl/hydra_config.py
+++ b/benchmarl/hydra_config.py
@@ -19,7 +19,9 @@
 from omegaconf import DictConfig, OmegaConf
 
 
-def load_experiment_from_hydra(cfg: DictConfig, task_name: str) -> Experiment:
+def load_experiment_from_hydra(
+    cfg: DictConfig, task_name: str, callbacks=()
+) -> Experiment:
     """Creates an :class:`~benchmarl.experiment.Experiment` from hydra config.
 
     Args:
@@ -43,6 +45,7 @@ def load_experiment_from_hydra(cfg: DictConfig, task_name: str) -> Experiment:
         critic_model_config=critic_model_config,
         seed=cfg.seed,
         config=experiment_config,
+        callbacks=callbacks,
     )
 
 
diff --git a/fine_tuned/smacv2/conf/config.yaml b/fine_tuned/smacv2/conf/config.yaml
index feff9b7f..4c2d3a48 100644
--- a/fine_tuned/smacv2/conf/config.yaml
+++ b/fine_tuned/smacv2/conf/config.yaml
@@ -16,6 +16,7 @@ seed: 0
 
 experiment:
   sampling_device: "cpu"
   train_device: "cuda"
+  buffer_device: "cuda"
   share_policy_params: True
   prefer_continuous_actions: True
diff --git a/fine_tuned/vmas/conf/config.yaml b/fine_tuned/vmas/conf/config.yaml
index dbc9f0a5..7dcf246e 100644
--- a/fine_tuned/vmas/conf/config.yaml
+++ b/fine_tuned/vmas/conf/config.yaml
@@ -17,6 +17,7 @@ experiment:
 
   sampling_device: "cuda"
   train_device: "cuda"
+  buffer_device: "cuda"
 
   share_policy_params: True
   prefer_continuous_actions: True
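
Usage sketch (reviewer note, not part of the patch): the snippet below shows how the new buffer_device option and the callbacks passthrough compose when an experiment is built from Python. It is a minimal sketch assuming BenchMARL's documented public API (Experiment, ExperimentConfig.get_from_yaml, VmasTask, IqlConfig, MlpConfig, Callback); the IQL/VMAS choices and the PrintingCallback body are illustrative assumptions, not part of this change.

    # Minimal sketch, assuming BenchMARL's documented Python API.
    from benchmarl.algorithms import IqlConfig
    from benchmarl.environments import VmasTask
    from benchmarl.experiment import Experiment, ExperimentConfig
    from benchmarl.experiment.callback import Callback
    from benchmarl.models.mlp import MlpConfig


    class PrintingCallback(Callback):
        # Hypothetical example callback; the hook name follows BenchMARL's
        # Callback interface.
        def on_batch_collected(self, batch):
            print(f"collected a batch of shape {batch.shape}")


    config = ExperimentConfig.get_from_yaml()
    config.sampling_device = "cpu"
    config.train_device = "cuda"
    # New in this patch: keep a large off-policy replay buffer in CPU RAM
    # while training on GPU; _optimizer_loop now moves each sampled
    # minibatch to train_device before the loss is computed.
    config.buffer_device = "cpu"

    experiment = Experiment(
        task=VmasTask.BALANCE.get_from_yaml(),
        algorithm_config=IqlConfig.get_from_yaml(),  # off-policy: buffer_device applies
        model_config=MlpConfig.get_from_yaml(),
        seed=0,
        config=config,
        callbacks=[PrintingCallback()],
    )
    experiment.run()

Note that on-policy algorithms ignore buffer_device: their buffers are refilled from scratch every iteration, so get_replay_buffer keeps them on train_device and only off-policy storages are placed on buffer_device.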