From b176244f288c8bd01fc8782a6f6e24edbcc61558 Mon Sep 17 00:00:00 2001
From: Matteo Bettini
Date: Thu, 14 Sep 2023 15:00:13 +0100
Subject: [PATCH] [Logging] Small fixes to logging

Signed-off-by: Matteo Bettini
---
 benchmarl/experiment/experiment.py | 25 ++++++++++++++++-----
 benchmarl/experiment/logger.py     | 35 ++++++++++++++++++------------
 benchmarl/hydra_run.py             |  6 +++--
 3 files changed, 45 insertions(+), 21 deletions(-)

diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py
index 8cd16c8a..5f9d0b3a 100644
--- a/benchmarl/experiment/experiment.py
+++ b/benchmarl/experiment/experiment.py
@@ -12,6 +12,7 @@
 from torchrl.envs import EnvBase, RewardSum, SerialEnv, TransformedEnv
 from torchrl.envs.transforms import Compose
 from torchrl.envs.utils import ExplorationType, set_exploration_type
+from tqdm import tqdm
 
 from benchmarl.algorithms.common import AlgorithmConfig
 from benchmarl.environments import Task
@@ -115,6 +116,11 @@ def __init__(
 
         self._setup()
 
+        self.total_time = 0
+        self.total_frames = 0
+        self.n_iters_performed = 0
+        self.mean_return = 0
+
     @property
     def on_policy(self) -> bool:
         return self.algorithm_config.on_policy()
@@ -268,9 +274,11 @@ def run(self):
         self._collection_loop()
 
     def _collection_loop(self):
-        self.total_time = 0
-        self.total_frames = 0
-        self.n_iters_performed = 0
+
+        pbar = tqdm(
+            initial=self.n_iters_performed,
+            total=self.config.n_iters,
+        )
 
         sampling_start = time.time()
         # Training/collection iterations
@@ -281,9 +289,11 @@ def _collection_loop(self):
             collection_time = time.time() - sampling_start
             current_frames = batch.numel()
             self.total_frames += current_frames
-            self.logger.log_collection(
+            self.mean_return = self.logger.log_collection(
                 batch, self.total_frames, step=self.n_iters_performed
             )
+            pbar.set_description(f"mean return = {self.mean_return}", refresh=False)
+            pbar.update()
 
             # Loop over groups
             training_start = time.time()
@@ -344,7 +354,12 @@ def _collection_loop(self):
             self.logger.commit()
 
             sampling_start = time.time()
+        self.close()
+
+    def close(self):
         self.collector.shutdown()
+        self.test_env.close()
+        self.logger.finish()
 
     def _get_excluded_keys(self, group: str):
         excluded_keys = []
@@ -375,7 +390,7 @@ def _optimizer_loop(self, group: str) -> TensorDictBase:
                 optimizer.step()
                 optimizer.zero_grad()
             elif loss_name.startswith("loss"):
-                assert False
+                raise AssertionError
         if self.target_updaters[group] is not None:
             self.target_updaters[group].step()
         return training_td
diff --git a/benchmarl/experiment/logger.py b/benchmarl/experiment/logger.py
index 6ee47879..aacf7fac 100644
--- a/benchmarl/experiment/logger.py
+++ b/benchmarl/experiment/logger.py
@@ -30,6 +30,7 @@ def __init__(
     ):
         self.experiment_config = experiment_config
         self.algorithm_name = algorithm_name
+        self.environment_name = environment_name
         self.task_name = task_name
         self.model_name = model_name
         self.group_map = group_map
@@ -64,6 +65,7 @@ def __init__(
                     wandb_kwargs={
                         "group": task_name,
                         "project": "benchmarl",
+                        "id": exp_name,
                     },
                 )
             )
@@ -75,6 +77,7 @@ def log_hparams(self, **kwargs):
                 "algorithm_name": self.algorithm_name,
                 "model_name": self.model_name,
                 "task_name": self.task_name,
+                "environment_name": self.environment_name,
                 "seed": self.seed,
             }
         )
@@ -85,17 +88,15 @@ def log_collection(
         batch: TensorDictBase,
         total_frames: int,
         step: int,
-    ):
+    ) -> float:
         if not len(self.loggers) and self.json_writer is None:
             return
         to_log = {}
-        group_returns = {}
+        json_metrics = {}
         for group in self.group_map.keys():
             episode_reward = self._get_episode_reward(group, batch)
             done = self._get_done(group, batch)
-            group_returns[group + "_return"] = episode_reward.mean(-2)[
-                done.any(-2)
-            ].tolist()
+            json_metrics[group + "_return"] = episode_reward.mean(-2)[done.any(-2)]
             reward = self._get_reward(group, batch)
             episode_reward = episode_reward[done]
             to_log.update(
@@ -115,11 +116,16 @@
                     for key, value in batch.get((group, "info")).items()
                 }
             )
+        mean_group_return = torch.stack(
+            [value for key, value in json_metrics.items()], dim=0
+        ).mean(0)
+        json_metrics["return"] = mean_group_return
         if self.json_writer is not None:
             self.json_writer.write(
-                metrics=group_returns, total_frames=total_frames, step=step
+                metrics=json_metrics, total_frames=total_frames, step=step
             )
         self.log(to_log, step=step)
+        return mean_group_return.mean().item()
 
     def log_training(self, group: str, training_td: TensorDictBase, step: int):
         if not len(self.loggers):
@@ -172,14 +178,7 @@ def log_evaluation(
             ).unsqueeze(0)
             for logger in self.loggers:
                 if isinstance(logger, WandbLogger):
-                    import wandb
-
-                    logger.experiment.log(
-                        {
-                            "eval/video": wandb.Video(vid, fps=20, format="mp4"),
-                        },
-                        commit=False,
-                    )
+                    logger.log_video("eval/video", vid, fps=20, commit=False)
                 else:
                     logger.log_video("eval_video", vid, step=step)
 
@@ -196,6 +195,13 @@ def log(self, dict_to_log: Dict, step: int = None):
         for key, value in dict_to_log.items():
             logger.log_scalar(key.replace("/", "_"), value, step=step)
 
+    def finish(self):
+        for logger in self.loggers:
+            if isinstance(logger, WandbLogger):
+                import wandb
+
+                wandb.finish()
+
     def _get_reward(
         self, group: str, td: TensorDictBase, remove_agent_dim: bool = False
     ):
@@ -251,6 +257,7 @@ def __init__(
         }
 
     def write(self, total_frames: int, metrics: Dict[str, Any], step: int):
+        metrics = {k: val.tolist() for k, val in metrics.items()}
         metrics.update({"step_count": total_frames})
         step_str = f"step_{step}"
         if step_str in self.run_data:
diff --git a/benchmarl/hydra_run.py b/benchmarl/hydra_run.py
index b9e147c2..465b761e 100644
--- a/benchmarl/hydra_run.py
+++ b/benchmarl/hydra_run.py
@@ -43,10 +43,12 @@ def load_model_from_hydra_config(cfg: DictConfig) -> ModelConfig:
 
 @hydra.main(version_base=None, config_path="conf", config_name="config")
 def hydra_experiment(cfg: DictConfig) -> None:
-    print("Loaded config:")
-    print(OmegaConf.to_yaml(cfg))
     hydra_choices = HydraConfig.get().runtime.choices
     task_name = hydra_choices.task
+    print(f"\nAlgorithm: {hydra_choices.algorithm}, Task: {task_name}")
+    print("\nLoaded config:\n")
+    print(OmegaConf.to_yaml(cfg))
+
     experiment = load_experiment_from_hydra_config(
         cfg,
         task_name=task_name,
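
For reference, a minimal standalone sketch (not part of the patch) of the return aggregation that the reworked log_collection / JsonWriter.write path performs. The group names and tensor values below are hypothetical; torch.stack assumes the per-group return tensors have matching shapes.

import torch

# Hypothetical per-group episode returns gathered during one collection iteration.
json_metrics = {
    "agents_return": torch.tensor([1.0, 2.0, 3.0]),
    "adversaries_return": torch.tensor([0.0, 1.0, 2.0]),
}

# Average across groups, as log_collection now does with torch.stack(...).mean(0).
mean_group_return = torch.stack(list(json_metrics.values()), dim=0).mean(0)
json_metrics["return"] = mean_group_return

# Scalar handed back to the experiment loop and shown in the tqdm description.
mean_return = mean_group_return.mean().item()  # 1.5

# JsonWriter.write now converts tensors to plain lists before serializing.
serializable = {k: v.tolist() for k, v in json_metrics.items()}
print(mean_return, serializable)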