diff --git a/benchmarl/conf/experiment/base_experiment.yaml b/benchmarl/conf/experiment/base_experiment.yaml index f2df25c3..0ef9ee98 100644 --- a/benchmarl/conf/experiment/base_experiment.yaml +++ b/benchmarl/conf/experiment/base_experiment.yaml @@ -2,8 +2,8 @@ defaults: - experiment_config - _self_ -sampling_device: "cuda" -train_device: "cuda" +sampling_device: "cpu" +train_device: "cpu" gamma: 0.99 polyak_tau: 0.005 lr: 0.00005 @@ -28,3 +28,6 @@ evaluation_episodes: 200 loggers: [wandb] create_json: True + +restore_file: null +checkpoint_interval: 0 diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index 5f9d0b3a..5cc5f104 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -1,7 +1,12 @@ -import pathlib +from __future__ import annotations + +import importlib +import os import time +from collections import OrderedDict from dataclasses import dataclass, MISSING -from typing import List, Optional +from pathlib import Path +from typing import Dict, List, Optional import torch @@ -12,6 +17,7 @@ from torchrl.envs import EnvBase, RewardSum, SerialEnv, TransformedEnv from torchrl.envs.transforms import Compose from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.record.loggers import generate_exp_name from tqdm import tqdm from benchmarl.algorithms.common import AlgorithmConfig @@ -20,6 +26,10 @@ from benchmarl.models.common import ModelConfig from benchmarl.utils import read_yaml_config +_has_hydra = importlib.util.find_spec("hydra") is not None +if _has_hydra: + from hydra.core.hydra_config import HydraConfig + @dataclass class ExperimentConfig: @@ -51,6 +61,9 @@ class ExperimentConfig: loggers: List[str] = MISSING create_json: bool = MISSING + restore_file: Optional[str] = MISSING + checkpoint_interval: float = MISSING + def train_batch_size(self, on_policy: bool) -> int: return ( self.collected_frames_per_batch @@ -88,7 +101,7 @@ def exploration_annealing_num_frames(self) -> int: def get_from_yaml(path: Optional[str] = None): if path is None: yaml_path = ( - pathlib.Path(__file__).parent.parent + Path(__file__).parent.parent / "conf" / "experiment" / "base_experiment.yaml" @@ -121,6 +134,9 @@ def __init__( self.n_iters_performed = 0 self.mean_return = 0 + if self.config.restore_file is not None: + self.load_trainer() + @property def on_policy(self) -> bool: return self.algorithm_config.on_policy() @@ -130,6 +146,7 @@ def _setup(self): self._setup_task() self._setup_algorithm() self._setup_collector() + self._setup_name() self._setup_logger() def _set_action_type(self): @@ -251,13 +268,37 @@ def _setup_collector(self): total_frames=self.config.total_frames, ) + def _setup_name(self): + self.algorithm_name = self.algorithm_config.associated_class().__name__.lower() + self.model_name = self.model_config.associated_class().__name__.lower() + self.environment_name = self.task.env_name().lower() + self.task_name = self.task.name.lower() + + if self.config.restore_file is None: + if _has_hydra and HydraConfig.initialized(): + folder_name = Path(HydraConfig.get().runtime.output_dir) + else: + folder_name = Path(os.getcwd()) + self.name = generate_exp_name( + f"{self.algorithm_name}_{self.task_name}_{self.model_name}", "" + ) + self.folder_name = folder_name / self.name + self.folder_name.mkdir(parents=False, exist_ok=False) + + else: + self.folder_name = Path(self.config.restore_file).parent.parent.resolve() + self.name = self.folder_name.name + def _setup_logger(self): + self.logger = MultiAgentLogger( - self.config, - algorithm_name=self.algorithm_config.associated_class().__name__.lower(), - model_name=self.model_config.associated_class().__name__.lower(), - environment_name=self.task.env_name().lower(), - task_name=self.task.name.lower(), + experiment_name=self.name, + folder_name=str(self.folder_name), + experiment_config=self.config, + algorithm_name=self.algorithm_name, + model_name=self.model_name, + environment_name=self.environment_name, + task_name=self.task_name, group_map=self.group_map, seed=self.seed, ) @@ -283,7 +324,6 @@ def _collection_loop(self): # Training/collection iterations for batch in self.collector: - print(f"Iteration {self.n_iters_performed}") # Logging collection collection_time = time.time() - sampling_start @@ -338,7 +378,7 @@ def _collection_loop(self): "timers/total_time": self.total_time, "counters/current_frames": current_frames, "counters/total_frames": self.total_frames, - "counters/total_iter": self.n_iters_performed, + "counters/iter": self.n_iters_performed, }, step=self.n_iters_performed, ) @@ -350,8 +390,14 @@ def _collection_loop(self): ): self._evaluation_loop(iter=self.n_iters_performed) - self.n_iters_performed += 1 + # End of step self.logger.commit() + if ( + self.config.checkpoint_interval > 0 + and self.n_iters_performed % self.config.checkpoint_interval == 0 + ): + self.save_trainer() + self.n_iters_performed += 1 sampling_start = time.time() self.close() @@ -433,3 +479,44 @@ def callback(env, td): evaluation_time = time.time() - evaluation_start self.logger.log({"timers/evaluation_time": evaluation_time}, step=iter) self.logger.log_evaluation(rollouts, frames, step=iter) + + # Saving trainer state + def state_dict(self) -> OrderedDict: + + state = OrderedDict( + total_time=self.total_time, + total_frames=self.total_frames, + n_iters_performed=self.n_iters_performed, + mean_return=self.mean_return, + ) + state_dict = OrderedDict( + state=state, + collector=self.collector.state_dict(), + **{f"loss_{k}": item.state_dict() for k, item in self.losses.items()}, + **{ + f"buffer_{k}": item.state_dict() + for k, item in self.replay_buffers.items() + }, + ) + return state_dict + + def load_state_dict(self, state_dict: Dict) -> None: + for group in self.group_map.keys(): + self.losses[group].load_state_dict(state_dict[f"loss_{group}"]) + self.replay_buffers[group].load_state_dict(state_dict[f"buffer_{group}"]) + self.collector.load_state_dict(state_dict["collector"]) + self.total_time = state_dict["state"]["total_time"] + self.total_frames = state_dict["state"]["total_frames"] + self.n_iters_performed = state_dict["state"]["n_iters_performed"] + self.mean_return = state_dict["state"]["mean_return"] + + def save_trainer(self) -> None: + checkpoint_folder = self.folder_name / "checkpoints" + checkpoint_folder.mkdir(parents=False, exist_ok=True) + checkpoint_file = checkpoint_folder / f"checkpoint_{self.n_iters_performed}.pt" + torch.save(self.state_dict(), checkpoint_file) + + def load_trainer(self) -> Experiment: + loaded_dict: OrderedDict = torch.load(self.config.restore_file) + self.load_state_dict(loaded_dict) + return self diff --git a/benchmarl/experiment/logger.py b/benchmarl/experiment/logger.py index aacf7fac..aec4d670 100644 --- a/benchmarl/experiment/logger.py +++ b/benchmarl/experiment/logger.py @@ -1,6 +1,5 @@ -import importlib import json -import os + from pathlib import Path from typing import Any, Dict, List, Optional @@ -8,18 +7,15 @@ import torch from tensordict import TensorDictBase -from torchrl.record.loggers import generate_exp_name, get_logger, Logger +from torchrl.record.loggers import get_logger, Logger from torchrl.record.loggers.wandb import WandbLogger -_has_hydra = importlib.util.find_spec("hydra") is not None - -if _has_hydra: - from hydra.core.hydra_config import HydraConfig - class MultiAgentLogger: def __init__( self, + experiment_name: str, + folder_name: str, experiment_config, algorithm_name: str, environment_name: str, @@ -36,17 +32,10 @@ def __init__( self.group_map = group_map self.seed = seed - exp_name = generate_exp_name(f"{algorithm_name}_{task_name}_{model_name}", "") - - if _has_hydra and HydraConfig.initialized(): - cwd = HydraConfig.get().runtime.output_dir - else: - cwd = str(Path(os.getcwd()) / exp_name) - if experiment_config.create_json: self.json_writer = JsonWriter( - folder=cwd, - name=exp_name + ".json", + folder=folder_name, + name=experiment_name + ".json", algorithm_name=algorithm_name, task_name=task_name, environment_name=environment_name, @@ -60,12 +49,12 @@ def __init__( self.loggers.append( get_logger( logger_type=logger_name, - logger_name=cwd, - experiment_name=exp_name, + logger_name=folder_name, + experiment_name=experiment_name, wandb_kwargs={ "group": task_name, "project": "benchmarl", - "id": exp_name, + "id": experiment_name, }, ) ) diff --git a/test/conf/experiment/base_experiment.yaml b/test/conf/experiment/base_experiment.yaml index 0d67fe1c..446dd5dd 100644 --- a/test/conf/experiment/base_experiment.yaml +++ b/test/conf/experiment/base_experiment.yaml @@ -28,3 +28,6 @@ evaluation_episodes: 200 loggers: [] create_json: False + +restore_file: +checkpoint_interval: 0