From 30940f905a73f7a7a6babc662cc2323a2683abe4 Mon Sep 17 00:00:00 2001 From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com> Date: Mon, 9 Sep 2024 23:37:49 +0200 Subject: [PATCH] [Feature] Improved experiment reloading and evaluation (#127) * able to specify save folder when reloading * able to specify save folder when reloading * amend * evaluate * evaluate * ignore * add examples * add examples * add examples * add examples * add examples * docs --- .gitignore | 1 + README.md | 9 +++ .../conf/experiment/base_experiment.yaml | 7 ++- benchmarl/evaluate.py | 21 +++++++ benchmarl/experiment/experiment.py | 63 ++++++++++++------- benchmarl/hydra_config.py | 56 +++++++++++++++++ benchmarl/resume.py | 22 +++++++ docs/source/concepts/features.rst | 45 ++++++++++++- examples/checkpointing/reload_experiment.py | 4 +- examples/checkpointing/reload_experiment.sh | 2 +- examples/checkpointing/resume_experiment.sh | 10 +++ examples/evaluating/evalaute_experiment.py | 26 ++++++++ examples/evaluating/evaluate_experiment.sh | 12 ++++ 13 files changed, 249 insertions(+), 29 deletions(-) create mode 100644 benchmarl/evaluate.py create mode 100644 benchmarl/resume.py create mode 100644 examples/checkpointing/resume_experiment.sh create mode 100644 examples/evaluating/evalaute_experiment.py create mode 100644 examples/evaluating/evaluate_experiment.sh diff --git a/.gitignore b/.gitignore index 24eac773..46d2831e 100644 --- a/.gitignore +++ b/.gitignore @@ -173,3 +173,4 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
.idea/ +!/examples/evaluate/outputs/ diff --git a/README.md b/README.md index 68acd265..0feb4a96 100644 --- a/README.md +++ b/README.md @@ -469,6 +469,15 @@ python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters [![Example](https://img.shields.io/badge/Example-blue.svg)](examples/checkpointing/reload_experiment.py) + +There are also ways to **resume** and **evaluate** hydra experiments directly from the file +```bash +python benchmarl/evaluate.py ../outputs/2024-09-09/20-39-31/mappo_balance_mlp__cd977b69_24_09_09-20_39_31/checkpoints/checkpoint_100.pt +``` +```bash +python benchmarl/resume.py ../outputs/2024-09-09/20-39-31/mappo_balance_mlp__cd977b69_24_09_09-20_39_31/checkpoints/checkpoint_100.pt +``` + ### Callbacks Experiments optionally take a list of [`Callback`](benchmarl/experiment/callback.py) which have several methods diff --git a/benchmarl/conf/experiment/base_experiment.yaml b/benchmarl/conf/experiment/base_experiment.yaml index 30d5d0ef..d1e68afe 100644 --- a/benchmarl/conf/experiment/base_experiment.yaml +++ b/benchmarl/conf/experiment/base_experiment.yaml @@ -86,7 +86,7 @@ evaluation_episodes: 10 evaluation_deterministic_actions: True # List of loggers to use, options are: wandb, csv, tensorboard, mflow -loggers: [] +loggers: [csv] # Wandb project name project_name: "benchmarl" # Create a json folder as part of the output in the format of marl-eval @@ -94,9 +94,14 @@ create_json: True # Absolute path to the folder where the experiment will log. # If null, this will default to the hydra output dir (if using hydra) or to the current folder when the script is run (if not). +# If you are reloading an experiment with "restore_file", this will default to the reloaded experiment folder. save_folder: null # Absolute path to a checkpoint file where the experiment was saved. If null the experiment is started fresh. restore_file: null +# Map location given to `torch.load()` when reloading. 
+# If you are reloading in a cpu-only machine a gpu experiment, you can use `restore_map_location: {"cuda":"cpu"}` +# to map gpu tensors to the cpu +restore_map_location: null # Interval for experiment saving in terms of collected frames (this should be a multiple of on/off_policy_collected_frames_per_batch). # Set it to 0 to disable checkpointing checkpoint_interval: 0 diff --git a/benchmarl/evaluate.py b/benchmarl/evaluate.py new file mode 100644 index 00000000..942bc4bb --- /dev/null +++ b/benchmarl/evaluate.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +import argparse +from pathlib import Path + +from benchmarl.hydra_config import reload_experiment_from_file + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Evaluates the experiment from a checkpoint file." + ) + parser.add_argument( + "checkpoint_file", type=str, help="The name of the checkpoint file" + ) + args = parser.parse_args() + checkpoint_file = str(Path(args.checkpoint_file).resolve()) + experiment = reload_experiment_from_file(checkpoint_file) + experiment.evaluate() diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index fa5ac129..e86c52e4 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -14,7 +14,7 @@ from collections import deque, OrderedDict from dataclasses import dataclass, MISSING from pathlib import Path -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import torch from tensordict import TensorDictBase @@ -99,6 +99,7 @@ class ExperimentConfig: save_folder: Optional[str] = MISSING restore_file: Optional[str] = MISSING + restore_map_location: Optional[Any] = MISSING checkpoint_interval: int = MISSING checkpoint_at_end: bool = MISSING keep_checkpoints_num: Optional[int] = MISSING @@ -506,33 +507,40 
@@ def _setup_name(self): self.task_name = self.task.name.lower() self._checkpointed_files = deque([]) - if self.config.restore_file is not None and self.config.save_folder is not None: - raise ValueError( - "Experiment restore file and save folder have both been specified." - "Do not set a save_folder when you are reloading an experiment as" - "it will by default reloaded into the old folder." - ) - if self.config.restore_file is None: - if self.config.save_folder is not None: - folder_name = Path(self.config.save_folder) + if self.config.save_folder is not None: + # If the user specified a folder for the experiment we use that + save_folder = Path(self.config.save_folder) + else: + # Otherwise, if the user is restoring from a folder, we will save in the folder they are restoring from + if self.config.restore_file is not None: + save_folder = Path( + self.config.restore_file + ).parent.parent.parent.resolve() + # Otherwise, the user is not restoring and did not specify a save_folder so we save in the hydra directory + # of the experiment or in the directory where the experiment was run (if hydra is not used) else: if _has_hydra and HydraConfig.initialized(): - folder_name = Path(HydraConfig.get().runtime.output_dir) + save_folder = Path(HydraConfig.get().runtime.output_dir) else: - folder_name = Path(os.getcwd()) + save_folder = Path(os.getcwd()) + + if self.config.restore_file is None: self.name = generate_exp_name( f"{self.algorithm_name}_{self.task_name}_{self.model_name}", "" ) - self.folder_name = folder_name / self.name - if ( - len(self.config.loggers) - or self.config.checkpoint_interval > 0 - or self.config.create_json - ): - self.folder_name.mkdir(parents=False, exist_ok=False) + self.folder_name = save_folder / self.name + else: - self.folder_name = Path(self.config.restore_file).parent.parent.resolve() - self.name = self.folder_name.name + # If restoring, we use the name of the previous experiment + self.name = 
Path(self.config.restore_file).parent.parent.resolve().name + self.folder_name = save_folder / self.name + + if ( + len(self.config.loggers) + or self.config.checkpoint_interval > 0 + or self.config.create_json + ): + self.folder_name.mkdir(parents=False, exist_ok=True) def _setup_logger(self): self.logger = Logger( @@ -570,6 +578,15 @@ def run(self): self.close() raise err + def evaluate(self): + """Run just the evaluation loop once.""" + self._evaluation_loop() + self.logger.commit() + print( + f"Evaluation results logged to loggers={self.config.loggers}" + f"{' and to a json file in the experiment folder.' if self.config.create_json else ''}" + ) + def _collection_loop(self): pbar = tqdm( initial=self.n_iters_performed, @@ -883,6 +900,8 @@ def _save_experiment(self) -> None: def _load_experiment(self) -> Experiment: """Load trainer from checkpoint""" - loaded_dict: OrderedDict = torch.load(self.config.restore_file) + loaded_dict: OrderedDict = torch.load( + self.config.restore_file, map_location=self.config.restore_map_location + ) self.load_state_dict(loaded_dict) return self diff --git a/benchmarl/hydra_config.py b/benchmarl/hydra_config.py index f83c25ad..8c3fa43a 100644 --- a/benchmarl/hydra_config.py +++ b/benchmarl/hydra_config.py @@ -5,6 +5,7 @@ # import importlib from dataclasses import is_dataclass +from pathlib import Path from benchmarl.algorithms.common import AlgorithmConfig from benchmarl.environments import Task, task_config_registry @@ -16,6 +17,7 @@ _has_hydra = importlib.util.find_spec("hydra") is not None if _has_hydra: + from hydra import compose, initialize, initialize_config_dir from omegaconf import DictConfig, OmegaConf @@ -121,3 +123,57 @@ def load_model_config_from_hydra(cfg: DictConfig) -> ModelConfig: OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) ) ) + + +def _find_hydra_folder(restore_file: str) -> str: + """Given the restore file, look for the .hydra folder max three levels above it.""" + current_folder = 
Path(restore_file).parent.resolve() + for _ in range(3): + hydra_dir = current_folder / ".hydra" + if hydra_dir.exists() and hydra_dir.is_dir(): + return str(hydra_dir) + current_folder = current_folder.parent + raise ValueError( + ".hydra folder not found (should be max 3 levels above checkpoint file" + ) + + +def reload_experiment_from_file(restore_file: str) -> Experiment: + """Reloads the experiment from a given restore file. + + Requires a ``.hydra`` folder containing ``config.yaml``, ``hydra.yaml``, and ``overrides.yaml`` + at max three directory levels higher than the checkpoint file. This should be automatically created by hydra. + + Args: + restore_file (str): The checkpoint file of the experiment reload. + + """ + hydra_folder = _find_hydra_folder(restore_file) + with initialize( + version_base=None, + config_path="conf", + ): + cfg = compose( + config_name="config", + overrides=OmegaConf.load(Path(hydra_folder) / "overrides.yaml"), + return_hydra_config=True, + ) + task_name = cfg.hydra.runtime.choices.task + algorithm_name = cfg.hydra.runtime.choices.algorithm + with initialize_config_dir(version_base=None, config_dir=hydra_folder): + cfg_loaded = dict(compose(config_name="config")) + + for key in ("experiment", "algorithm", "task", "model", "critic_model"): + cfg[key].update(cfg_loaded[key]) + cfg_loaded.pop(key) + + cfg.update(cfg_loaded) + del cfg.hydra + cfg.experiment.restore_file = restore_file + + print("\nReloaded experiment with:") + print(f"\nAlgorithm: {algorithm_name}, Task: {task_name}") + print("\nLoaded config:\n") + print(OmegaConf.to_yaml(cfg)) + + return load_experiment_from_hydra(cfg, task_name=task_name) diff --git a/benchmarl/resume.py b/benchmarl/resume.py new file mode 100644 index 00000000..fb5d3898 --- /dev/null +++ b/benchmarl/resume.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +import argparse +from pathlib import Path + +from benchmarl.hydra_config import reload_experiment_from_file + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Resumes the experiment from a checkpoint file." + ) + parser.add_argument( + "checkpoint_file", type=str, help="The name of the checkpoint file" + ) + args = parser.parse_args() + checkpoint_file = str(Path(args.checkpoint_file).resolve()) + + experiment = reload_experiment_from_file(checkpoint_file) + experiment.run() diff --git a/docs/source/concepts/features.rst b/docs/source/concepts/features.rst index 1f09478d..cec3ce8d 100644 --- a/docs/source/concepts/features.rst +++ b/docs/source/concepts/features.rst @@ -38,18 +38,57 @@ Their checkpoints will be stored in a ``"checkpoints"`` folder within the experi python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters=3 experiment.on_policy_collected_frames_per_batch=100 experiment.checkpoint_interval=100 +.. python_example_button:: + https://github.com/facebookresearch/BenchMARL/blob/main/examples/checkpointing/reload_experiment.py -To load from a checkpoint, pass the absolute checkpoint file name to ``experiment.restore_file``. +Reloading +--------- -.. code-block:: console +To load from a checkpoint, you can do it in multiple ways: - python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters=6 experiment.on_policy_collected_frames_per_batch=100 experiment.restore_file="/hydra/experiment/folder/checkpoint/checkpoint_300.pt" +You can pass the absolute checkpoint file name to ``experiment.restore_file``. +.. code-block:: console + python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters=6 experiment.on_policy_collected_frames_per_batch=100 experiment.restore_file="/hydra/experiment/folder/checkpoint/checkpoint_300.pt" .. 
python_example_button:: https://github.com/facebookresearch/BenchMARL/blob/main/examples/checkpointing/reload_experiment.py +If you do not need to change the config, you can also just resume from the checkpoint file with: + +.. code-block:: console + + python benchmarl/resume.py ../outputs/2024-09-09/20-39-31/mappo_balance_mlp__cd977b69_24_09_09-20_39_31/checkpoints/checkpoint_100.pt + +In Python, this is equivalent to: + +.. code-block:: python + + from benchmarl.hydra_config import reload_experiment_from_file + experiment = reload_experiment_from_file(checkpoint_file) + experiment.run() + + +Evaluating +---------- + +To evaluate an experiment, you can: + +.. code-block:: python + + from benchmarl.hydra_config import reload_experiment_from_file + experiment = reload_experiment_from_file(checkpoint_file) + experiment.evaluate() + +This will run an iteration of evaluation, logging it to the experiment loggers (and to json if ``create_json==True``). + +There is a command line script which automates this: + +..
code-block:: console + + python benchmarl/evaluate.py ../outputs/2024-09-09/20-39-31/mappo_balance_mlp__cd977b69_24_09_09-20_39_31/checkpoints/checkpoint_100.pt + Callbacks --------- diff --git a/examples/checkpointing/reload_experiment.py b/examples/checkpointing/reload_experiment.py index 705662d8..0cf6b54d 100644 --- a/examples/checkpointing/reload_experiment.py +++ b/examples/checkpointing/reload_experiment.py @@ -38,13 +38,13 @@ ) experiment.run() - # Now we tell it where to restore from + # Now we tell it where to restore from experiment_config.restore_file = ( experiment.folder_name / "checkpoints" / f"checkpoint_{experiment.total_frames}.pt" ) - # The experiment will be saved in the ame folder as the one it is restoring from + # The experiment will be saved in the same folder as the one it is restoring from experiment_config.save_folder = None # Let's do 3 more iters experiment_config.max_n_iters += 3 diff --git a/examples/checkpointing/reload_experiment.sh b/examples/checkpointing/reload_experiment.sh index b461d5e9..1052a1a8 100644 --- a/examples/checkpointing/reload_experiment.sh +++ b/examples/checkpointing/reload_experiment.sh @@ -7,4 +7,4 @@ # python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters=3 experiment.on_policy_collected_frames_per_batch=100 experiment.checkpoint_interval=100 -python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters=6 experiment.on_policy_collected_frames_per_batch=100 experiment.restore_file="/hydra/experiment/folder/checkpoint/checkpoint_300.pt" +python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters=6 experiment.on_policy_collected_frames_per_batch=100 experiment.restore_file="/hydra_experiment_folder/checkpoint/checkpoint_300.pt" diff --git a/examples/checkpointing/resume_experiment.sh b/examples/checkpointing/resume_experiment.sh new file mode 100644 index 00000000..8023b297 --- /dev/null +++ b/examples/checkpointing/resume_experiment.sh @@ -0,0 
+1,10 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# + +python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters=3 experiment.on_policy_collected_frames_per_batch=100 experiment.checkpoint_interval=100 +python benchmarl/resume.py /hydra_experiment_folder/checkpoint/checkpoint_200.pt diff --git a/examples/evaluating/evalaute_experiment.py b/examples/evaluating/evalaute_experiment.py new file mode 100644 index 00000000..171ca73d --- /dev/null +++ b/examples/evaluating/evalaute_experiment.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +from pathlib import Path + +from benchmarl.hydra_config import reload_experiment_from_file + +if __name__ == "__main__": + + # Let's assume that we have run an experiment with + # `python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters=2 experiment.on_policy_collected_frames_per_batch=100 experiment.checkpoint_interval=100` + # and we have obtained + # "outputs/2024-09-09/20-39-31/mappo_balance_mlp__cd977b69_24_09_09-20_39_31/checkpoints/checkpoint_100.pt" + + # Now we tell it where to restore from + current_folder = Path(__file__).parent.absolute() + restore_file = ( + current_folder + / "outputs/2024-09-09/20-39-31/mappo_balance_mlp__cd977b69_24_09_09-20_39_31/checkpoints/checkpoint_100.pt" + ) + + experiment = reload_experiment_from_file(str(restore_file)) + experiment.evaluate() diff --git a/examples/evaluating/evaluate_experiment.sh b/examples/evaluating/evaluate_experiment.sh new file mode 100644 index 00000000..3aacf78a --- /dev/null +++ b/examples/evaluating/evaluate_experiment.sh @@ -0,0 +1,12 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates.
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# + +# Assume we have run the experiment with +# python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters=2 experiment.on_policy_collected_frames_per_batch=100 experiment.checkpoint_interval=100 +# Evaluate it at the given checkpoint +python benchmarl/evaluate.py ../outputs/2024-09-09/20-39-31/mappo_balance_mlp__cd977b69_24_09_09-20_39_31/checkpoints/checkpoint_100.pt