Skip to content

Commit

Permalink
[Feature] Improved experiment reloading and evaluation (#127)
Browse files Browse the repository at this point in the history
* able to specify save folder when reloading

* able to specify save folder when reloading

* amend

* evaluate

* evaluate

* ignore

* add examples

* add examples

* add examples

* add examples

* add examples

* docs
  • Loading branch information
matteobettini committed Sep 9, 2024
1 parent 308228e commit 30940f9
Show file tree
Hide file tree
Showing 13 changed files with 249 additions and 29 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,4 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
!/examples/evaluate/outputs/
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,15 @@ python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters

[![Example](https://img.shields.io/badge/Example-blue.svg)](examples/checkpointing/reload_experiment.py)


There are also ways to **resume** and **evaluate** hydra experiments directly from the file
```bash
python benchmarl/evaluate.py ../outputs/2024-09-09/20-39-31/mappo_balance_mlp__cd977b69_24_09_09-20_39_31/checkpoints/checkpoint_100.pt
```
```bash
python benchmarl/resume.py ../outputs/2024-09-09/20-39-31/mappo_balance_mlp__cd977b69_24_09_09-20_39_31/checkpoints/checkpoint_100.pt
```

### Callbacks

Experiments optionally take a list of [`Callback`](benchmarl/experiment/callback.py) which have several methods
Expand Down
7 changes: 6 additions & 1 deletion benchmarl/conf/experiment/base_experiment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,22 @@ evaluation_episodes: 10
evaluation_deterministic_actions: True

# List of loggers to use, options are: wandb, csv, tensorboard, mflow
loggers: []
loggers: [csv]
# Wandb project name
project_name: "benchmarl"
# Create a json folder as part of the output in the format of marl-eval
create_json: True

# Absolute path to the folder where the experiment will log.
# If null, this will default to the hydra output dir (if using hydra) or to the current folder when the script is run (if not).
# If you are reloading an experiment with "restore_file", this will default to the reloaded experiment folder.
save_folder: null
# Absolute path to a checkpoint file where the experiment was saved. If null the experiment is started fresh.
restore_file: null
# Map location given to `torch.load()` when reloading.
# If you are reloading in a cpu-only machine a gpu experiment, you can use `restore_map_location: {"cuda":"cpu"}`
# to map gpu tensors to the cpu
restore_map_location: null
# Interval for experiment saving in terms of collected frames (this should be a multiple of on/off_policy_collected_frames_per_batch).
# Set it to 0 to disable checkpointing
checkpoint_interval: 0
Expand Down
21 changes: 21 additions & 0 deletions benchmarl/evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import argparse
from pathlib import Path

from benchmarl.hydra_config import reload_experiment_from_file

if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Evaluates the experiment from a checkpoint file."
)
parser.add_argument(
"checkpoint_file", type=str, help="The name of the checkpoint file"
)
args = parser.parse_args()
checkpoint_file = str(Path(args.checkpoint_file).resolve())
experiment = reload_experiment_from_file(checkpoint_file)
experiment.evaluate()
63 changes: 41 additions & 22 deletions benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from collections import deque, OrderedDict
from dataclasses import dataclass, MISSING
from pathlib import Path
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

import torch
from tensordict import TensorDictBase
Expand Down Expand Up @@ -99,6 +99,7 @@ class ExperimentConfig:

save_folder: Optional[str] = MISSING
restore_file: Optional[str] = MISSING
restore_map_location: Optional[Any] = MISSING
checkpoint_interval: int = MISSING
checkpoint_at_end: bool = MISSING
keep_checkpoints_num: Optional[int] = MISSING
Expand Down Expand Up @@ -506,33 +507,40 @@ def _setup_name(self):
self.task_name = self.task.name.lower()
self._checkpointed_files = deque([])

if self.config.restore_file is not None and self.config.save_folder is not None:
raise ValueError(
"Experiment restore file and save folder have both been specified."
"Do not set a save_folder when you are reloading an experiment as"
"it will by default reloaded into the old folder."
)
if self.config.restore_file is None:
if self.config.save_folder is not None:
folder_name = Path(self.config.save_folder)
if self.config.save_folder is not None:
# If the user specified a folder for the experiment we use that
save_folder = Path(self.config.save_folder)
else:
# Otherwise, if the user is restoring from a folder, we will save in the folder they are restoring from
if self.config.restore_file is not None:
save_folder = Path(
self.config.restore_file
).parent.parent.parent.resolve()
# Otherwise, the user is not restoring and did not specify a save_folder so we save in the hydra directory
# of the experiment or in the directory where the experiment was run (if hydra is not used)
else:
if _has_hydra and HydraConfig.initialized():
folder_name = Path(HydraConfig.get().runtime.output_dir)
save_folder = Path(HydraConfig.get().runtime.output_dir)
else:
folder_name = Path(os.getcwd())
save_folder = Path(os.getcwd())

if self.config.restore_file is None:
self.name = generate_exp_name(
f"{self.algorithm_name}_{self.task_name}_{self.model_name}", ""
)
self.folder_name = folder_name / self.name
if (
len(self.config.loggers)
or self.config.checkpoint_interval > 0
or self.config.create_json
):
self.folder_name.mkdir(parents=False, exist_ok=False)
self.folder_name = save_folder / self.name

else:
self.folder_name = Path(self.config.restore_file).parent.parent.resolve()
self.name = self.folder_name.name
# If restoring, we use the name of the previous experiment
self.name = Path(self.config.restore_file).parent.parent.resolve().name
self.folder_name = save_folder / self.name

if (
len(self.config.loggers)
or self.config.checkpoint_interval > 0
or self.config.create_json
):
self.folder_name.mkdir(parents=False, exist_ok=True)

def _setup_logger(self):
self.logger = Logger(
Expand Down Expand Up @@ -570,6 +578,15 @@ def run(self):
self.close()
raise err

def evaluate(self):
"""Run just the evaluation loop once."""
self._evaluation_loop()
self.logger.commit()
print(
f"Evaluation results logged to loggers={self.config.loggers}"
f"{' and to a json file in the experiment folder.' if self.config.create_json else ''}"
)

def _collection_loop(self):
pbar = tqdm(
initial=self.n_iters_performed,
Expand Down Expand Up @@ -883,6 +900,8 @@ def _save_experiment(self) -> None:

def _load_experiment(self) -> Experiment:
"""Load trainer from checkpoint"""
loaded_dict: OrderedDict = torch.load(self.config.restore_file)
loaded_dict: OrderedDict = torch.load(
self.config.restore_file, map_location=self.config.restore_map_location
)
self.load_state_dict(loaded_dict)
return self
56 changes: 56 additions & 0 deletions benchmarl/hydra_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#
import importlib
from dataclasses import is_dataclass
from pathlib import Path

from benchmarl.algorithms.common import AlgorithmConfig
from benchmarl.environments import Task, task_config_registry
Expand All @@ -16,6 +17,7 @@
_has_hydra = importlib.util.find_spec("hydra") is not None

if _has_hydra:
from hydra import compose, initialize, initialize_config_dir
from omegaconf import DictConfig, OmegaConf


Expand Down Expand Up @@ -121,3 +123,57 @@ def load_model_config_from_hydra(cfg: DictConfig) -> ModelConfig:
OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
)
)


def _find_hydra_folder(restore_file: str) -> str:
"""Given the restore file, look for the .hydra folder max three levels above it."""
current_folder = Path(restore_file).parent.resolve()
for _ in range(3):
hydra_dir = current_folder / ".hydra"
if hydra_dir.exists() and hydra_dir.is_dir():
return str(hydra_dir)
current_folder = current_folder.parent
raise ValueError(
".hydra folder not found (should be max 3 levels above checkpoint file"
)


def reload_experiment_from_file(restore_file: str) -> Experiment:
"""Reloads the experiment from a given restore file.
Requires a ``.hydra`` folder containing ``config.yaml``, ``hydra.yaml``, and ``overrides.yaml``
at max three directory levels higher than the checkpoint file. This should be automatically created by hydra.
Args:
restore_file (str): The checkpoint file of the experiment reload.
"""
hydra_folder = _find_hydra_folder(restore_file)
with initialize(
version_base=None,
config_path="conf",
):
cfg = compose(
config_name="config",
overrides=OmegaConf.load(Path(hydra_folder) / "overrides.yaml"),
return_hydra_config=True,
)
task_name = cfg.hydra.runtime.choices.task
algorithm_name = cfg.hydra.runtime.choices.algorithm
with initialize_config_dir(version_base=None, config_dir=hydra_folder):
cfg_loaded = dict(compose(config_name="config"))

for key in ("experiment", "algorithm", "task", "model", "critic_model"):
cfg[key].update(cfg_loaded[key])
cfg_loaded.pop(key)

cfg.update(cfg_loaded)
del cfg.hydra
cfg.experiment.restore_file = restore_file

print("\nReloaded experiment with:")
print(f"\nAlgorithm: {algorithm_name}, Task: {task_name}")
print("\nLoaded config:\n")
print(OmegaConf.to_yaml(cfg))

return load_experiment_from_hydra(cfg, task_name=task_name)
22 changes: 22 additions & 0 deletions benchmarl/resume.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import argparse
from pathlib import Path

from benchmarl.hydra_config import reload_experiment_from_file

if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Resumes the experiment from a checkpoint file."
)
parser.add_argument(
"checkpoint_file", type=str, help="The name of the checkpoint file"
)
args = parser.parse_args()
checkpoint_file = str(Path(args.checkpoint_file).resolve())

experiment = reload_experiment_from_file(checkpoint_file)
experiment.run()
45 changes: 42 additions & 3 deletions docs/source/concepts/features.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,57 @@ Their checkpoints will be stored in a ``"checkpoints"`` folder within the experi
python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters=3 experiment.on_policy_collected_frames_per_batch=100 experiment.checkpoint_interval=100
.. python_example_button::
https://github.com/facebookresearch/BenchMARL/blob/main/examples/checkpointing/reload_experiment.py

To load from a checkpoint, pass the absolute checkpoint file name to ``experiment.restore_file``.
Reloading
---------

.. code-block:: console
To load from a checkpoint, you can do it in multiple ways:

python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters=6 experiment.on_policy_collected_frames_per_batch=100 experiment.restore_file="/hydra/experiment/folder/checkpoint/checkpoint_300.pt"
You can pass the absolute checkpoint file name to ``experiment.restore_file``.

.. code-block:: console
python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters=6 experiment.on_policy_collected_frames_per_batch=100 experiment.restore_file="/hydra/experiment/folder/checkpoint/checkpoint_300.pt"
.. python_example_button::
https://github.com/facebookresearch/BenchMARL/blob/main/examples/checkpointing/reload_experiment.py

If you do not need to change the config, you can also just resume from the checkpoint file with:

.. code-block:: console
python benchmarl/resume.py ../outputs/2024-09-09/20-39-31/mappo_balance_mlp__cd977b69_24_09_09-20_39_31/checkpoints/checkpoint_100.pt
In Python, this is equivalent to:

.. code-block:: python
from benchmarl.hydra_config import reload_experiment_from_file
experiment = reload_experiment_from_file(checkpoint_file)
experiment.run()
Evaluating
----------

To evaluate an experiment, you can:

.. code-block:: python
from benchmarl.hydra_config import reload_experiment_from_file
experiment = reload_experiment_from_file(checkpoint_file)
experiment.evaluate()
This will run an iteration of evaluation, logging it to the experiment loggers (and to json if ``create_json==True``.

There is a command line script which automates this:

.. code-block:: console
python benchmarl/evaluate.py ../outputs/2024-09-09/20-39-31/mappo_balance_mlp__cd977b69_24_09_09-20_39_31/checkpoints/checkpoint_100.pt
Callbacks
---------
Expand Down
4 changes: 2 additions & 2 deletions examples/checkpointing/reload_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@
)
experiment.run()

# Now we tell it where to restore from
# Now we tell it where to restore from
experiment_config.restore_file = (
experiment.folder_name
/ "checkpoints"
/ f"checkpoint_{experiment.total_frames}.pt"
)
# The experiment will be saved in the ame folder as the one it is restoring from
# The experiment will be saved in the same folder as the one it is restoring from
experiment_config.save_folder = None
# Let's do 3 more iters
experiment_config.max_n_iters += 3
Expand Down
2 changes: 1 addition & 1 deletion examples/checkpointing/reload_experiment.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
#

python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters=3 experiment.on_policy_collected_frames_per_batch=100 experiment.checkpoint_interval=100
python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters=6 experiment.on_policy_collected_frames_per_batch=100 experiment.restore_file="/hydra/experiment/folder/checkpoint/checkpoint_300.pt"
python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters=6 experiment.on_policy_collected_frames_per_batch=100 experiment.restore_file="/hydra_experiment_folder/checkpoint/checkpoint_300.pt"
10 changes: 10 additions & 0 deletions examples/checkpointing/resume_experiment.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
#

python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters=3 experiment.on_policy_collected_frames_per_batch=100 experiment.checkpoint_interval=100
python benchmarl/resume.py /hydra_experiment_folder/checkpoint/checkpoint_200.pt
26 changes: 26 additions & 0 deletions examples/evaluating/evalaute_experiment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from pathlib import Path

from benchmarl.hydra_config import reload_experiment_from_file

if __name__ == "__main__":

# Let's assume that we have run an experiment with
# `python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters=2 experiment.on_policy_collected_frames_per_batch=100 experiment.checkpoint_interval=100`
# and we have obtained
# "outputs/2024-09-09/20-39-31/mappo_balance_mlp__cd977b69_24_09_09-20_39_31/checkpoints/checkpoint_100.pt""

# Now we tell it where to restore from
current_folder = Path(__file__).parent.absolute()
restore_file = (
current_folder
/ "outputs/2024-09-09/20-39-31/mappo_balance_mlp__cd977b69_24_09_09-20_39_31/checkpoints/checkpoint_100.pt"
)

experiment = reload_experiment_from_file(str(restore_file))
experiment.evaluate()
Loading

0 comments on commit 30940f9

Please sign in to comment.