[Checkpointing] Save/load trainer (#12)
* amend

* ippo one optim

* revert ippo

* cfg

* update task config

* remove callback

* readme loggers

* allow wandb loading

* add finishing logging

* more printing

* fixes

* add .pt

* move saving

---------

Signed-off-by: Matteo Bettini <[email protected]>
matteobettini committed Sep 15, 2023
1 parent b176244 commit 0cff333
Showing 4 changed files with 115 additions and 33 deletions.
7 changes: 5 additions & 2 deletions benchmarl/conf/experiment/base_experiment.yaml
@@ -2,8 +2,8 @@ defaults:
   - experiment_config
   - _self_
 
-sampling_device: "cuda"
-train_device: "cuda"
+sampling_device: "cpu"
+train_device: "cpu"
 gamma: 0.99
 polyak_tau: 0.005
 lr: 0.00005
@@ -28,3 +28,6 @@ evaluation_episodes: 200
 
 loggers: [wandb]
 create_json: True
+
+restore_file: null
+checkpoint_interval: 0
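
The two options added above drive the checkpointing introduced by this commit: checkpoint_interval is the number of collection iterations between saves (0 disables checkpointing), and restore_file points at a previously saved checkpoint to resume from. A minimal sketch of setting them programmatically — the import path is an assumption, and get_from_yaml() is the loader visible in the experiment.py diff below:

from benchmarl.experiment import ExperimentConfig  # import path assumed

config = ExperimentConfig.get_from_yaml()  # defaults from base_experiment.yaml
config.checkpoint_interval = 10            # save every 10 collection iterations; 0 disables
config.restore_file = None                 # or a path to a saved checkpoint_<iter>.pt to resume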
109 changes: 98 additions & 11 deletions benchmarl/experiment/experiment.py
@@ -1,7 +1,12 @@
-import pathlib
+from __future__ import annotations
+
+import importlib
+import os
 import time
+from collections import OrderedDict
 from dataclasses import dataclass, MISSING
-from typing import List, Optional
+from pathlib import Path
+from typing import Dict, List, Optional
 
 import torch
 
@@ -12,6 +17,7 @@
 from torchrl.envs import EnvBase, RewardSum, SerialEnv, TransformedEnv
 from torchrl.envs.transforms import Compose
 from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.record.loggers import generate_exp_name
 from tqdm import tqdm
 
 from benchmarl.algorithms.common import AlgorithmConfig
@@ -20,6 +26,10 @@
 from benchmarl.models.common import ModelConfig
 from benchmarl.utils import read_yaml_config
 
+_has_hydra = importlib.util.find_spec("hydra") is not None
+if _has_hydra:
+    from hydra.core.hydra_config import HydraConfig
+
 
 @dataclass
 class ExperimentConfig:
@@ -51,6 +61,9 @@ class ExperimentConfig:
     loggers: List[str] = MISSING
     create_json: bool = MISSING
 
+    restore_file: Optional[str] = MISSING
+    checkpoint_interval: float = MISSING
+
     def train_batch_size(self, on_policy: bool) -> int:
         return (
             self.collected_frames_per_batch
@@ -88,7 +101,7 @@ def exploration_annealing_num_frames(self) -> int:
     def get_from_yaml(path: Optional[str] = None):
         if path is None:
             yaml_path = (
-                pathlib.Path(__file__).parent.parent
+                Path(__file__).parent.parent
                 / "conf"
                 / "experiment"
                 / "base_experiment.yaml"
@@ -121,6 +134,9 @@ def __init__(
         self.n_iters_performed = 0
         self.mean_return = 0
 
+        if self.config.restore_file is not None:
+            self.load_trainer()
+
     @property
     def on_policy(self) -> bool:
         return self.algorithm_config.on_policy()
@@ -130,6 +146,7 @@ def _setup(self):
         self._setup_task()
         self._setup_algorithm()
         self._setup_collector()
+        self._setup_name()
         self._setup_logger()
 
     def _set_action_type(self):
@@ -251,13 +268,37 @@ def _setup_collector(self):
             total_frames=self.config.total_frames,
         )
 
+    def _setup_name(self):
+        self.algorithm_name = self.algorithm_config.associated_class().__name__.lower()
+        self.model_name = self.model_config.associated_class().__name__.lower()
+        self.environment_name = self.task.env_name().lower()
+        self.task_name = self.task.name.lower()
+
+        if self.config.restore_file is None:
+            if _has_hydra and HydraConfig.initialized():
+                folder_name = Path(HydraConfig.get().runtime.output_dir)
+            else:
+                folder_name = Path(os.getcwd())
+            self.name = generate_exp_name(
+                f"{self.algorithm_name}_{self.task_name}_{self.model_name}", ""
+            )
+            self.folder_name = folder_name / self.name
+            self.folder_name.mkdir(parents=False, exist_ok=False)
+
+        else:
+            self.folder_name = Path(self.config.restore_file).parent.parent.resolve()
+            self.name = self.folder_name.name
+
     def _setup_logger(self):
+
         self.logger = MultiAgentLogger(
-            self.config,
-            algorithm_name=self.algorithm_config.associated_class().__name__.lower(),
-            model_name=self.model_config.associated_class().__name__.lower(),
-            environment_name=self.task.env_name().lower(),
-            task_name=self.task.name.lower(),
+            experiment_name=self.name,
+            folder_name=str(self.folder_name),
+            experiment_config=self.config,
+            algorithm_name=self.algorithm_name,
+            model_name=self.model_name,
+            environment_name=self.environment_name,
+            task_name=self.task_name,
             group_map=self.group_map,
             seed=self.seed,
         )
@@ -283,7 +324,6 @@ def _collection_loop(self):
 
         # Training/collection iterations
         for batch in self.collector:
-            print(f"Iteration {self.n_iters_performed}")
 
             # Logging collection
             collection_time = time.time() - sampling_start
@@ -338,7 +378,7 @@ def _collection_loop(self):
                     "timers/total_time": self.total_time,
                     "counters/current_frames": current_frames,
                     "counters/total_frames": self.total_frames,
-                    "counters/total_iter": self.n_iters_performed,
+                    "counters/iter": self.n_iters_performed,
                 },
                 step=self.n_iters_performed,
             )
@@ -350,8 +390,14 @@
             ):
                 self._evaluation_loop(iter=self.n_iters_performed)
 
-            self.n_iters_performed += 1
+            # End of step
             self.logger.commit()
+            if (
+                self.config.checkpoint_interval > 0
+                and self.n_iters_performed % self.config.checkpoint_interval == 0
+            ):
+                self.save_trainer()
+            self.n_iters_performed += 1
             sampling_start = time.time()
 
         self.close()
@@ -433,3 +479,44 @@ def callback(env, td):
         evaluation_time = time.time() - evaluation_start
         self.logger.log({"timers/evaluation_time": evaluation_time}, step=iter)
         self.logger.log_evaluation(rollouts, frames, step=iter)
+
+    # Saving trainer state
+    def state_dict(self) -> OrderedDict:
+
+        state = OrderedDict(
+            total_time=self.total_time,
+            total_frames=self.total_frames,
+            n_iters_performed=self.n_iters_performed,
+            mean_return=self.mean_return,
+        )
+        state_dict = OrderedDict(
+            state=state,
+            collector=self.collector.state_dict(),
+            **{f"loss_{k}": item.state_dict() for k, item in self.losses.items()},
+            **{
+                f"buffer_{k}": item.state_dict()
+                for k, item in self.replay_buffers.items()
+            },
+        )
+        return state_dict
+
+    def load_state_dict(self, state_dict: Dict) -> None:
+        for group in self.group_map.keys():
+            self.losses[group].load_state_dict(state_dict[f"loss_{group}"])
+            self.replay_buffers[group].load_state_dict(state_dict[f"buffer_{group}"])
+        self.collector.load_state_dict(state_dict["collector"])
+        self.total_time = state_dict["state"]["total_time"]
+        self.total_frames = state_dict["state"]["total_frames"]
+        self.n_iters_performed = state_dict["state"]["n_iters_performed"]
+        self.mean_return = state_dict["state"]["mean_return"]
+
+    def save_trainer(self) -> None:
+        checkpoint_folder = self.folder_name / "checkpoints"
+        checkpoint_folder.mkdir(parents=False, exist_ok=True)
+        checkpoint_file = checkpoint_folder / f"checkpoint_{self.n_iters_performed}.pt"
+        torch.save(self.state_dict(), checkpoint_file)
+
+    def load_trainer(self) -> Experiment:
+        loaded_dict: OrderedDict = torch.load(self.config.restore_file)
+        self.load_state_dict(loaded_dict)
+        return self
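
Taken together, the new methods make resuming a run a matter of pointing restore_file at a saved checkpoint: __init__ then calls load_trainer(), which restores the losses, replay buffers, collector state, and the time/frame/iteration counters, while _setup_name() recovers the original experiment folder from the checkpoint path (checkpoints/checkpoint_<iter>.pt sits two directory levels below it). A hedged sketch of a resume; the constructor arguments and the run() entry point are assumed from the surrounding codebase rather than shown in this diff:

from benchmarl.experiment import Experiment, ExperimentConfig  # import path assumed

config = ExperimentConfig.get_from_yaml()
# Hypothetical path: <experiment_folder>/checkpoints/checkpoint_100.pt
config.restore_file = "outputs/mappo_balance_mlp/checkpoints/checkpoint_100.pt"

experiment = Experiment(   # arguments assumed: same task/algorithm/model as the original run
    task=task,
    algorithm_config=algorithm_config,
    model_config=model_config,
    seed=0,
    config=config,
)
# load_trainer() already ran inside __init__, so training resumes from the
# restored n_iters_performed instead of iteration 0.
experiment.run()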
29 changes: 9 additions & 20 deletions benchmarl/experiment/logger.py
@@ -1,25 +1,21 @@
-import importlib
 import json
 import os
 
 from pathlib import Path
 from typing import Any, Dict, List, Optional
 
 import numpy as np
 import torch
 
 from tensordict import TensorDictBase
-from torchrl.record.loggers import generate_exp_name, get_logger, Logger
+from torchrl.record.loggers import get_logger, Logger
 from torchrl.record.loggers.wandb import WandbLogger
 
-_has_hydra = importlib.util.find_spec("hydra") is not None
-
-if _has_hydra:
-    from hydra.core.hydra_config import HydraConfig
-
 
 class MultiAgentLogger:
     def __init__(
         self,
+        experiment_name: str,
+        folder_name: str,
         experiment_config,
         algorithm_name: str,
         environment_name: str,
@@ -36,17 +32,10 @@ def __init__(
         self.group_map = group_map
         self.seed = seed
 
-        exp_name = generate_exp_name(f"{algorithm_name}_{task_name}_{model_name}", "")
-
-        if _has_hydra and HydraConfig.initialized():
-            cwd = HydraConfig.get().runtime.output_dir
-        else:
-            cwd = str(Path(os.getcwd()) / exp_name)
-
         if experiment_config.create_json:
             self.json_writer = JsonWriter(
-                folder=cwd,
-                name=exp_name + ".json",
+                folder=folder_name,
+                name=experiment_name + ".json",
                 algorithm_name=algorithm_name,
                 task_name=task_name,
                 environment_name=environment_name,
@@ -60,12 +49,12 @@ def __init__(
             self.loggers.append(
                 get_logger(
                     logger_type=logger_name,
-                    logger_name=cwd,
-                    experiment_name=exp_name,
+                    logger_name=folder_name,
+                    experiment_name=experiment_name,
                     wandb_kwargs={
                         "group": task_name,
                         "project": "benchmarl",
-                        "id": exp_name,
+                        "id": experiment_name,
                     },
                 )
            )
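
With naming and folder resolution moved into Experiment._setup_name(), MultiAgentLogger no longer computes anything itself: it receives experiment_name and folder_name ready-made and reuses them for the JSON writer and every backend logger, which also keeps the wandb run id stable across restores. A sketch of the updated call, mirroring the _setup_logger() hunk in the experiment.py diff with placeholder values; the import path and the group_map format are assumptions:

from benchmarl.experiment.logger import MultiAgentLogger  # module path assumed

logger = MultiAgentLogger(
    experiment_name="mappo_balance_mlp__hypothetical",
    folder_name="/tmp/mappo_balance_mlp__hypothetical",
    experiment_config=config,   # an ExperimentConfig with loggers/create_json set
    algorithm_name="mappo",
    model_name="mlp",
    environment_name="vmas",
    task_name="balance",
    group_map={"agents": ["agent_0", "agent_1"]},  # assumed group-to-agents mapping
    seed=0,
)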
3 changes: 3 additions & 0 deletions test/conf/experiment/base_experiment.yaml
@@ -28,3 +28,6 @@ evaluation_episodes: 200
 
 loggers: []
 create_json: False
+
+restore_file:
+checkpoint_interval: 0
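
Note the bare restore_file: in the test config: YAML parses a key with no value as null, so this is equivalent to the explicit restore_file: null used in the main config. A quick PyYAML check:

import yaml

# A key with no value parses to None, matching "restore_file: null"
assert yaml.safe_load("restore_file:") == {"restore_file": None}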
