From 11c55c29c8c2f4dfaab623e826a8d71778097f47 Mon Sep 17 00:00:00 2001
From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
Date: Mon, 17 Jun 2024 11:06:54 +0200
Subject: [PATCH] [Feature] PPO minibatch advantage (#100)

---
 benchmarl/algorithms/ippo.py        | 54 +++++++++++++++++++++--------
 benchmarl/algorithms/mappo.py       | 54 +++++++++++++++++++++--------
 benchmarl/conf/algorithm/ippo.yaml  |  1 +
 benchmarl/conf/algorithm/mappo.yaml |  1 +
 benchmarl/experiment/experiment.py  |  6 ++--
 5 files changed, 84 insertions(+), 32 deletions(-)

diff --git a/benchmarl/algorithms/ippo.py b/benchmarl/algorithms/ippo.py
index 45e7e7c8..aac2cd88 100644
--- a/benchmarl/algorithms/ippo.py
+++ b/benchmarl/algorithms/ippo.py
@@ -36,6 +36,9 @@ class Ippo(Algorithm):
             choices: "softplus", "exp", "relu", "biased_softplus_1";
         use_tanh_normal (bool): if ``True``, use TanhNormal as the continuous action distribution
             with support bound to the action domain. Otherwise, an IndependentNormal is used.
+        minibatch_advantage (bool): if ``True``, advantage computation is performed on minibatches of size
+            ``experiment.config.on_policy_minibatch_size`` instead of the full
+            ``experiment.config.on_policy_collected_frames_per_batch``. This helps keep memory usage from exploding.
 
     """
 
@@ -49,6 +52,7 @@ def __init__(
         lmbda: float,
         scale_mapping: str,
         use_tanh_normal: bool,
+        minibatch_advantage: bool,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -61,6 +65,7 @@ def __init__(
         self.lmbda = lmbda
         self.scale_mapping = scale_mapping
         self.use_tanh_normal = use_tanh_normal
+        self.minibatch_advantage = minibatch_advantage
 
     #############################
     # Overridden abstract methods
@@ -148,15 +153,17 @@ def _get_policy_for_loss(
             spec=self.action_spec[group, "action"],
             in_keys=[(group, "loc"), (group, "scale")],
             out_keys=[(group, "action")],
-            distribution_class=IndependentNormal
-            if not self.use_tanh_normal
-            else TanhNormal,
-            distribution_kwargs={
-                "min": self.action_spec[(group, "action")].space.low,
-                "max": self.action_spec[(group, "action")].space.high,
-            }
-            if self.use_tanh_normal
-            else {},
+            distribution_class=(
+                IndependentNormal if not self.use_tanh_normal else TanhNormal
+            ),
+            distribution_kwargs=(
+                {
+                    "min": self.action_spec[(group, "action")].space.low,
+                    "max": self.action_spec[(group, "action")].space.high,
+                }
+                if self.use_tanh_normal
+                else {}
+            ),
             return_log_prob=True,
             log_prob_key=(group, "log_prob"),
         )
@@ -221,14 +228,30 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
             batch.get(("next", "reward")).unsqueeze(-1).expand((*group_shape, 1)),
         )
 
-        with torch.no_grad():
-            loss = self.get_loss_and_updater(group)[0]
-            loss.value_estimator(
-                batch,
-                params=loss.critic_network_params,
-                target_params=loss.target_critic_network_params,
+        loss = self.get_loss_and_updater(group)[0]
+        if self.minibatch_advantage:
+            increment = -(
+                -self.experiment.config.train_minibatch_size(self.on_policy)
+                // batch.shape[1]
             )
+        else:
+            increment = batch.batch_size[0] + 1
+        last_start_index = 0
+        start_index = increment
+        minibatches = []
+        while last_start_index < batch.shape[0]:
+            minibatch = batch[last_start_index:start_index]
+            minibatches.append(minibatch)
+            with torch.no_grad():
+                loss.value_estimator(
+                    minibatch,
+                    params=loss.critic_network_params,
+                    target_params=loss.target_critic_network_params,
+                )
+            last_start_index = start_index
+            start_index += increment
+        batch = torch.cat(minibatches, dim=0)
 
         return batch
 
     def process_loss_vals(
@@ -285,6 +308,7 @@ class IppoConfig(AlgorithmConfig):
     lmbda: float = MISSING
     scale_mapping: str = MISSING
     use_tanh_normal: bool = MISSING
+    minibatch_advantage: bool = MISSING
 
     @staticmethod
     def associated_class() -> Type[Algorithm]:
diff --git a/benchmarl/algorithms/mappo.py b/benchmarl/algorithms/mappo.py
index 0a391d2e..891200ef 100644
--- a/benchmarl/algorithms/mappo.py
+++ b/benchmarl/algorithms/mappo.py
@@ -40,6 +40,9 @@ class Mappo(Algorithm):
             choices: "softplus", "exp", "relu", "biased_softplus_1";
         use_tanh_normal (bool): if ``True``, use TanhNormal as the continuous action distribution
             with support bound to the action domain. Otherwise, an IndependentNormal is used.
+        minibatch_advantage (bool): if ``True``, advantage computation is performed on minibatches of size
+            ``experiment.config.on_policy_minibatch_size`` instead of the full
+            ``experiment.config.on_policy_collected_frames_per_batch``. This helps keep memory usage from exploding.
 
     """
 
@@ -53,6 +56,7 @@ def __init__(
         lmbda: float,
         scale_mapping: str,
         use_tanh_normal: bool,
+        minibatch_advantage: bool,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -65,6 +69,7 @@ def __init__(
         self.lmbda = lmbda
         self.scale_mapping = scale_mapping
         self.use_tanh_normal = use_tanh_normal
+        self.minibatch_advantage = minibatch_advantage
 
     #############################
     # Overridden abstract methods
@@ -152,15 +157,17 @@ def _get_policy_for_loss(
             spec=self.action_spec[group, "action"],
             in_keys=[(group, "loc"), (group, "scale")],
             out_keys=[(group, "action")],
-            distribution_class=IndependentNormal
-            if not self.use_tanh_normal
-            else TanhNormal,
-            distribution_kwargs={
-                "min": self.action_spec[(group, "action")].space.low,
-                "max": self.action_spec[(group, "action")].space.high,
-            }
-            if self.use_tanh_normal
-            else {},
+            distribution_class=(
+                IndependentNormal if not self.use_tanh_normal else TanhNormal
+            ),
+            distribution_kwargs=(
+                {
+                    "min": self.action_spec[(group, "action")].space.low,
+                    "max": self.action_spec[(group, "action")].space.high,
+                }
+                if self.use_tanh_normal
+                else {}
+            ),
             return_log_prob=True,
             log_prob_key=(group, "log_prob"),
         )
@@ -225,14 +232,30 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
             batch.get(("next", "reward")).unsqueeze(-1).expand((*group_shape, 1)),
         )
 
-        with torch.no_grad():
-            loss = self.get_loss_and_updater(group)[0]
-            loss.value_estimator(
-                batch,
-                params=loss.critic_network_params,
-                target_params=loss.target_critic_network_params,
+        loss = self.get_loss_and_updater(group)[0]
+        if self.minibatch_advantage:
+            increment = -(
+                -self.experiment.config.train_minibatch_size(self.on_policy)
+                // batch.shape[1]
             )
+        else:
+            increment = batch.batch_size[0] + 1
+        last_start_index = 0
+        start_index = increment
+        minibatches = []
+        while last_start_index < batch.shape[0]:
+            minibatch = batch[last_start_index:start_index]
+            minibatches.append(minibatch)
+            with torch.no_grad():
+                loss.value_estimator(
+                    minibatch,
+                    params=loss.critic_network_params,
+                    target_params=loss.target_critic_network_params,
+                )
+            last_start_index = start_index
+            start_index += increment
+        batch = torch.cat(minibatches, dim=0)
 
         return batch
 
     def process_loss_vals(
@@ -321,6 +344,7 @@ class MappoConfig(AlgorithmConfig):
     lmbda: float = MISSING
     scale_mapping: str = MISSING
     use_tanh_normal: bool = MISSING
+    minibatch_advantage: bool = MISSING
 
     @staticmethod
     def associated_class() -> Type[Algorithm]:
diff --git a/benchmarl/conf/algorithm/ippo.yaml b/benchmarl/conf/algorithm/ippo.yaml
index 2cda60df..2d248845 100644
--- a/benchmarl/conf/algorithm/ippo.yaml
+++ b/benchmarl/conf/algorithm/ippo.yaml
@@ -11,3 +11,4 @@ loss_critic_type: "l2"
 lmbda: 0.9
 scale_mapping: "biased_softplus_1.0"
 use_tanh_normal: True
+minibatch_advantage: False
diff --git a/benchmarl/conf/algorithm/mappo.yaml b/benchmarl/conf/algorithm/mappo.yaml
index db194d5f..d889ba48 100644
--- a/benchmarl/conf/algorithm/mappo.yaml
+++ b/benchmarl/conf/algorithm/mappo.yaml
@@ -12,3 +12,4 @@ loss_critic_type: "l2"
 lmbda: 0.9
 scale_mapping: "biased_softplus_1.0"
 use_tanh_normal: True
+minibatch_advantage: False
diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py
index b45d0e18..58a0e818 100644
--- a/benchmarl/experiment/experiment.py
+++ b/benchmarl/experiment/experiment.py
@@ -606,8 +606,10 @@ def _collection_loop(self):
                 training_tds = []
                 for _ in range(self.config.n_optimizer_steps(self.on_policy)):
                     for _ in range(
-                        self.config.train_batch_size(self.on_policy)
-                        // self.config.train_minibatch_size(self.on_policy)
+                        -(
+                            -self.config.train_batch_size(self.on_policy)
+                            // self.config.train_minibatch_size(self.on_policy)
+                        )
                     ):
                         training_tds.append(self._optimizer_loop(group))
                 training_td = torch.stack(training_tds)
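
Note on the new `process_batch` logic: the ceil-divided `increment` counts rows along the first batch dimension so that each slice holds roughly `train_minibatch_size` frames, the value estimator then runs on each slice under `torch.no_grad()`, and the slices are concatenated back so the rest of the training loop sees a batch of the original size. A minimal, self-contained sketch of that pattern (plain `tensordict`/PyTorch; `fake_value_estimator`, the tensor shapes, and `minibatch_size` are illustrative stand-ins, not BenchMARL APIs):

```python
import torch
from tensordict import TensorDict


def fake_value_estimator(td: TensorDict) -> None:
    # Stand-in for loss.value_estimator(...): writes a new entry into the
    # tensordict in place, the way the GAE estimator adds advantage targets.
    td["advantage"] = td["observation"].sum(-1, keepdim=True)


batch = TensorDict(
    {"observation": torch.randn(60, 100, 8)},  # (n_envs, frames_per_env, obs_dim)
    batch_size=[60, 100],
)
minibatch_size = 400  # target number of frames per advantage pass

# Ceil division (the -(-a // b) idiom used in the patch): rows along dim 0
# per chunk so that each chunk holds about `minibatch_size` frames in total.
increment = -(-minibatch_size // batch.shape[1])  # ceil(400 / 100) = 4

last_start, start = 0, increment
minibatches = []
while last_start < batch.shape[0]:
    minibatch = batch[last_start:start]
    minibatches.append(minibatch)
    with torch.no_grad():  # advantage targets carry no gradients
        fake_value_estimator(minibatch)
    last_start, start = start, start + increment

batch = torch.cat(minibatches, dim=0)  # reassembled: same frames, lower peak memory
assert batch.batch_size == torch.Size([60, 100]) and "advantage" in batch.keys()
```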
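
The `experiment.py` hunk switches the minibatch count from floor to ceiling division using the same `-(-a // b)` idiom, so a trailing partial minibatch is still visited instead of being dropped. A quick standalone check of the idiom (plain Python, numbers chosen for illustration):

```python
def ceil_div(a: int, b: int) -> int:
    # -(-a // b) == ceil(a / b) for positive integers, without math.ceil:
    # floor division of the negated numerator rounds toward -inf, and the
    # outer negation turns that into rounding up.
    return -(-a // b)


assert ceil_div(6000, 400) == 15  # evenly divisible: same as floor division
assert ceil_div(6100, 400) == 16  # remainder: one extra pass covers the leftover frames
```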
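
Enabling the feature: the YAML defaults above ship with `minibatch_advantage: False`, and the flag can be flipped either in the config file or on the algorithm config dataclass. A usage sketch with the Python API (assumes the usual `get_from_yaml()` entry point; the rest of the `Experiment` wiring is unchanged and elided):

```python
from benchmarl.algorithms import MappoConfig

# Load the defaults from benchmarl/conf/algorithm/mappo.yaml, then opt in to
# minibatched advantage computation for large collected batches.
algorithm_config = MappoConfig.get_from_yaml()
algorithm_config.minibatch_advantage = True
# Pass `algorithm_config` to Experiment(...) as usual.
```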