diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index c43cc8ca..90ef568b 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -557,7 +557,6 @@ def _collection_loop(self): initial=self.n_iters_performed, total=self.config.get_max_n_iters(self.on_policy), ) - sampling_start = time.time() if not self.config.collect_with_grad: iterator = iter(self.collector) @@ -568,6 +567,7 @@ def _collection_loop(self): for _ in range( self.n_iters_performed, self.config.get_max_n_iters(self.on_policy) ): + iteration_start = time.time() if not self.config.collect_with_grad: batch = next(iterator) else: @@ -585,7 +585,7 @@ def _collection_loop(self): reset_batch = step_mdp(batch[..., -1]) # Logging collection - collection_time = time.time() - sampling_start + collection_time = time.time() - iteration_start current_frames = batch.numel() self.total_frames += current_frames self.mean_return = self.logger.log_collection( @@ -637,22 +637,8 @@ def _collection_loop(self): if not self.config.collect_with_grad: self.collector.update_policy_weights_() - # Timers + # Training timer training_time = time.time() - training_start - iteration_time = collection_time + training_time - self.total_time += iteration_time - self.logger.log( - { - "timers/collection_time": collection_time, - "timers/training_time": training_time, - "timers/iteration_time": iteration_time, - "timers/total_time": self.total_time, - "counters/current_frames": current_frames, - "counters/total_frames": self.total_frames, - "counters/iter": self.n_iters_performed, - }, - step=self.n_iters_performed, - ) # Evaluation if ( @@ -666,6 +652,20 @@ def _collection_loop(self): self._evaluation_loop() # End of step + iteration_time = time.time() - iteration_start + self.total_time += iteration_time + self.logger.log( + { + "timers/collection_time": collection_time, + "timers/training_time": training_time, + "timers/iteration_time": iteration_time, + "timers/total_time": self.total_time, + "counters/current_frames": current_frames, + "counters/total_frames": self.total_frames, + "counters/iter": self.n_iters_performed, + }, + step=self.n_iters_performed, + ) self.n_iters_performed += 1 self.logger.commit() if ( @@ -674,7 +674,6 @@ def _collection_loop(self): ): self._save_experiment() pbar.update() - sampling_start = time.time() if self.config.checkpoint_at_end: self._save_experiment()