[Logging] Small fixes to logging
Signed-off-by: Matteo Bettini <[email protected]>
matteobettini committed Sep 14, 2023
1 parent 040b949 commit b176244
Showing 3 changed files with 45 additions and 21 deletions.
25 changes: 20 additions & 5 deletions benchmarl/experiment/experiment.py
@@ -12,6 +12,7 @@
 from torchrl.envs import EnvBase, RewardSum, SerialEnv, TransformedEnv
 from torchrl.envs.transforms import Compose
 from torchrl.envs.utils import ExplorationType, set_exploration_type
+from tqdm import tqdm
 
 from benchmarl.algorithms.common import AlgorithmConfig
 from benchmarl.environments import Task
@@ -115,6 +116,11 @@ def __init__(
 
         self._setup()
 
+        self.total_time = 0
+        self.total_frames = 0
+        self.n_iters_performed = 0
+        self.mean_return = 0
+
     @property
     def on_policy(self) -> bool:
         return self.algorithm_config.on_policy()
@@ -268,9 +274,11 @@ def run(self):
         self._collection_loop()
 
     def _collection_loop(self):
-        self.total_time = 0
-        self.total_frames = 0
-        self.n_iters_performed = 0
-
+        pbar = tqdm(
+            initial=self.n_iters_performed,
+            total=self.config.n_iters,
+        )
         sampling_start = time.time()
 
         # Training/collection iterations
@@ -281,9 +289,11 @@
             collection_time = time.time() - sampling_start
             current_frames = batch.numel()
             self.total_frames += current_frames
-            self.logger.log_collection(
+            self.mean_return = self.logger.log_collection(
                 batch, self.total_frames, step=self.n_iters_performed
             )
+            pbar.set_description(f"mean return = {self.mean_return}", refresh=False)
+            pbar.update()
 
             # Loop over groups
             training_start = time.time()
@@ -344,7 +354,12 @@ def _collection_loop(self):
             self.logger.commit()
             sampling_start = time.time()
 
+        self.close()
+
+    def close(self):
         self.collector.shutdown()
         self.test_env.close()
+        self.logger.finish()
 
     def _get_excluded_keys(self, group: str):
         excluded_keys = []
@@ -375,7 +390,7 @@ def _optimizer_loop(self, group: str) -> TensorDictBase:
                 optimizer.step()
                 optimizer.zero_grad()
             elif loss_name.startswith("loss"):
-                assert False
+                raise AssertionError
             if self.target_updaters[group] is not None:
                 self.target_updaters[group].step()
         return training_td
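For reference, the progress-bar wiring above follows the standard tqdm pattern: construct the bar with initial/total, defer redraws with refresh=False, and repaint once per iteration via update(). A minimal runnable sketch, with an illustrative iteration count and a stand-in value replacing the call to log_collection:

from tqdm import tqdm

n_iters_performed = 0  # starts at 0, as in Experiment.__init__
n_iters = 5            # illustrative; the real value comes from self.config.n_iters

pbar = tqdm(initial=n_iters_performed, total=n_iters)
for i in range(n_iters):
    mean_return = float(i)  # stand-in for self.logger.log_collection(...)
    # refresh=False defers the redraw so the bar repaints once, on update()
    pbar.set_description(f"mean return = {mean_return}", refresh=False)
    pbar.update()
pbar.close()
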
35 changes: 21 additions & 14 deletions benchmarl/experiment/logger.py
@@ -30,6 +30,7 @@ def __init__(
     ):
         self.experiment_config = experiment_config
         self.algorithm_name = algorithm_name
+        self.environment_name = environment_name
         self.task_name = task_name
         self.model_name = model_name
         self.group_map = group_map
@@ -64,6 +65,7 @@ def __init__(
                     wandb_kwargs={
                         "group": task_name,
                         "project": "benchmarl",
+                        "id": exp_name,
                     },
                 )
             )
@@ -75,6 +77,7 @@ def log_hparams(self, **kwargs):
                 "algorithm_name": self.algorithm_name,
                 "model_name": self.model_name,
                 "task_name": self.task_name,
+                "environment_name": self.environment_name,
                 "seed": self.seed,
             }
         )
@@ -85,17 +88,15 @@ def log_collection(
         batch: TensorDictBase,
         total_frames: int,
         step: int,
-    ):
+    ) -> float:
         if not len(self.loggers) and self.json_writer is None:
             return
         to_log = {}
-        group_returns = {}
+        json_metrics = {}
         for group in self.group_map.keys():
             episode_reward = self._get_episode_reward(group, batch)
             done = self._get_done(group, batch)
-            group_returns[group + "_return"] = episode_reward.mean(-2)[
-                done.any(-2)
-            ].tolist()
+            json_metrics[group + "_return"] = episode_reward.mean(-2)[done.any(-2)]
             reward = self._get_reward(group, batch)
             episode_reward = episode_reward[done]
             to_log.update(
@@ -115,11 +116,16 @@
                     for key, value in batch.get((group, "info")).items()
                 }
             )
+        mean_group_return = torch.stack(
+            [value for key, value in json_metrics.items()], dim=0
+        ).mean(0)
+        json_metrics["return"] = mean_group_return
         if self.json_writer is not None:
             self.json_writer.write(
-                metrics=group_returns, total_frames=total_frames, step=step
+                metrics=json_metrics, total_frames=total_frames, step=step
             )
         self.log(to_log, step=step)
+        return mean_group_return.mean().item()
 
     def log_training(self, group: str, training_td: TensorDictBase, step: int):
         if not len(self.loggers):
@@ -172,14 +178,7 @@ def log_evaluation(
         ).unsqueeze(0)
         for logger in self.loggers:
             if isinstance(logger, WandbLogger):
-                import wandb
-
-                logger.experiment.log(
-                    {
-                        "eval/video": wandb.Video(vid, fps=20, format="mp4"),
-                    },
-                    commit=False,
-                )
+                logger.log_video("eval/video", vid, fps=20, commit=False)
             else:
                 logger.log_video("eval_video", vid, step=step)
 
@@ -196,6 +195,13 @@ def log(self, dict_to_log: Dict, step: int = None):
             for key, value in dict_to_log.items():
                 logger.log_scalar(key.replace("/", "_"), value, step=step)
 
+    def finish(self):
+        for logger in self.loggers:
+            if isinstance(logger, WandbLogger):
+                import wandb
+
+                wandb.finish()
+
     def _get_reward(
         self, group: str, td: TensorDictBase, remove_agent_dim: bool = False
     ):
@@ -251,6 +257,7 @@ def __init__(
         }
 
     def write(self, total_frames: int, metrics: Dict[str, Any], step: int):
+        metrics = {k: val.tolist() for k, val in metrics.items()}
         metrics.update({"step_count": total_frames})
         step_str = f"step_{step}"
         if step_str in self.run_data:
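The new mean-return computation in log_collection stacks the per-group episode returns and averages across groups; the scalar it returns is what drives the tqdm description in experiment.py. A self-contained sketch with invented group names and values:

import torch

# Invented per-group episode returns (the real ones come from the batch)
json_metrics = {
    "agents_return": torch.tensor([1.0, 2.0]),
    "adversaries_return": torch.tensor([0.5, 1.5]),
}
# Stack the group tensors and average element-wise across groups
mean_group_return = torch.stack(list(json_metrics.values()), dim=0).mean(0)
json_metrics["return"] = mean_group_return  # logged to JSON alongside the groups
print(mean_group_return.mean().item())      # 1.25: scalar fed to the progress bar

Since the returns are now kept as tensors, JsonWriter.write converts each one to a plain list via .tolist() before serialization, as shown in the hunk above.
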
6 changes: 4 additions & 2 deletions benchmarl/hydra_run.py
@@ -43,10 +43,12 @@ def load_model_from_hydra_config(cfg: DictConfig) -> ModelConfig:
 
 @hydra.main(version_base=None, config_path="conf", config_name="config")
 def hydra_experiment(cfg: DictConfig) -> None:
-    print("Loaded config:")
-    print(OmegaConf.to_yaml(cfg))
     hydra_choices = HydraConfig.get().runtime.choices
     task_name = hydra_choices.task
+    print(f"\nAlgorithm: {hydra_choices.algorithm}, Task: {task_name}")
+    print("\nLoaded config:\n")
+    print(OmegaConf.to_yaml(cfg))
 
     experiment = load_experiment_from_hydra_config(
         cfg,
         task_name=task_name,
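The reordering above prints the selected algorithm and task before dumping the composed config. A toy illustration with an invented config; the real cfg is composed by Hydra, and the choices come from HydraConfig.get().runtime.choices:

from omegaconf import OmegaConf

cfg = OmegaConf.create({"experiment": {"n_iters": 10}, "seed": 0})  # invented contents
algorithm, task_name = "mappo", "vmas/balance"  # stand-ins for the Hydra runtime choices
print(f"\nAlgorithm: {algorithm}, Task: {task_name}")
print("\nLoaded config:\n")
print(OmegaConf.to_yaml(cfg))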
