[Feature] use_vmap=False for SAC #2392

Open · wants to merge 2 commits into base: gh/vmoens/18/base
65 changes: 65 additions & 0 deletions test/test_cost.py
@@ -3650,6 +3650,7 @@ def _create_seq_mock_data_sac(
@pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
@pytest.mark.parametrize("use_vmap", [False, True])
def test_sac(
self,
delay_value,
@@ -3659,6 +3660,7 @@ def test_sac(
device,
version,
td_est,
use_vmap,
):
if (delay_actor or delay_qvalue) and not delay_value:
pytest.skip("incompatible config")
@@ -3687,6 +3689,7 @@ def test_sac(
value_network=value,
num_qvalue_nets=num_qvalue,
loss_function="l2",
use_vmap=use_vmap,
**kwargs,
)

@@ -3811,6 +3814,68 @@ def test_sac(
p.grad is None or p.grad.norm() == 0.0
), f"target parameter {name} (shape: {p.shape}) has a non-null gradient"

@pytest.mark.parametrize("device", get_default_devices())
def test_sac_vmap_equiv(
self,
device,
version,
delay_value=True,
delay_actor=True,
delay_qvalue=True,
num_qvalue=4,
td_est=None,
):
if (delay_actor or delay_qvalue) and not delay_value:
pytest.skip("incompatible config")

torch.manual_seed(self.seed)
td = self._create_mock_data_sac(device=device)

actor = self._create_mock_actor(device=device)
qvalue = self._create_mock_qvalue(device=device)
if version == 1:
value = self._create_mock_value(device=device)
else:
value = None

kwargs = {}
if delay_actor:
kwargs["delay_actor"] = True
if delay_qvalue:
kwargs["delay_qvalue"] = True
if delay_value:
kwargs["delay_value"] = True

loss_fn_vmap = SACLoss(
actor_network=actor,
qvalue_network=qvalue,
value_network=value,
num_qvalue_nets=num_qvalue,
loss_function="l2",
use_vmap=True,
**kwargs,
)
loss_fn_novmap = SACLoss(
actor_network=actor,
qvalue_network=qvalue,
value_network=value,
num_qvalue_nets=num_qvalue,
loss_function="l2",
use_vmap=False,
**kwargs,
)
loss_fn_novmap.load_state_dict(loss_fn_vmap.state_dict())

with torch.no_grad(), _check_td_steady(td), pytest.warns(
UserWarning, match="No target network updater"
):
rng_state = torch.random.get_rng_state()
loss_vmap = loss_fn_vmap(td.clone())
torch.random.set_rng_state(rng_state)
loss_novmap = loss_fn_novmap(td.clone())

assert_allclose_td(loss_vmap, loss_novmap)

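A note on the RNG bookkeeping above: sampling from the actor is stochastic, so the two loss modules only produce comparable numbers if they consume identical random draws. A standalone sketch of the snapshot/rewind pattern (illustrative, not part of the diff):

```python
import torch

# Snapshot the global CPU RNG, draw, rewind, draw again: the two draws match,
# which is what makes loss_fn_vmap and loss_fn_novmap directly comparable.
rng_state = torch.random.get_rng_state()
a = torch.randn(3)
torch.random.set_rng_state(rng_state)
b = torch.randn(3)
assert torch.equal(a, b)
```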
@pytest.mark.parametrize("delay_value", (True, False))
@pytest.mark.parametrize("delay_actor", (True, False))
@pytest.mark.parametrize("delay_qvalue", (True, False))
1 change: 1 addition & 0 deletions torchrl/objectives/common.py
@@ -108,6 +108,7 @@ class _AcceptedKeys:

_vmap_randomness = None
default_value_estimator: ValueEstimators = None
use_vmap: bool = True

deterministic_sampling_mode: ExplorationType = ExplorationType.DETERMINISTIC

42 changes: 34 additions & 8 deletions torchrl/objectives/sac.py
@@ -28,6 +28,7 @@
from torchrl.objectives.utils import (
_cache_values,
_GAMMA_LMBDA_DEPREC_ERROR,
_LoopVmapModule,
_reduce,
_vmap_func,
default_value_kwargs,
@@ -113,6 +114,9 @@ class SACLoss(LossModule):
``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
``"mean"``: the sum of the output will be divided by the number of
elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
use_vmap (bool, optional): Whether :func:`~torch.vmap` should be used to batch
operations. Defaults to ``True``.
.. note:: Not using ``vmap`` offers greater flexibility but may incur a slower runtime.

Examples:
>>> import torch
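As a usage sketch, the new flag slots into the constructor like any other option. The toy actor and Q-value networks below are illustrative placeholders in the spirit of the docstring's Examples, not code from this diff:

```python
import torch
from torch import nn
from tensordict.nn import NormalParamExtractor, TensorDictModule
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import SACLoss

n_obs, n_act = 3, 4

# Gaussian policy head: a linear layer split into loc/scale by NormalParamExtractor.
actor_net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
actor = ProbabilisticActor(
    TensorDictModule(actor_net, in_keys=["observation"], out_keys=["loc", "scale"]),
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
)

class QNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(n_obs + n_act, 1)

    def forward(self, obs, act):
        return self.net(torch.cat([obs, act], dim=-1))

# With "action" among the in_keys, the output key defaults to "state_action_value".
qvalue = ValueOperator(QNet(), in_keys=["observation", "action"])

loss_fn = SACLoss(
    actor_network=actor,
    qvalue_network=qvalue,
    num_qvalue_nets=2,
    use_vmap=False,  # loop over the Q-networks instead of vmapping over their params
)
```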
@@ -307,7 +311,9 @@ def __init__(
priority_key: str = None,
separate_losses: bool = False,
reduction: str = None,
use_vmap: bool = True,
) -> None:
self.use_vmap = use_vmap
self._in_keys = None
self._out_keys = None
if reduction is None:
@@ -407,13 +413,22 @@ def __init__(
self.reduction = reduction

def _make_vmap(self):
-        self._vmap_qnetworkN0 = _vmap_func(
-            self.qvalue_network, (None, 0), randomness=self.vmap_randomness
-        )
-        if self._version == 1:
-            self._vmap_qnetwork00 = _vmap_func(
-                self.qvalue_network, randomness=self.vmap_randomness
-            )
+        if self.use_vmap:
+            self._vmap_qnetworkN0 = _vmap_func(
+                self.qvalue_network, (None, 0), randomness=self.vmap_randomness
+            )
+            if self._version == 1:
+                self._vmap_qnetwork00 = _vmap_func(
+                    self.qvalue_network, randomness=self.vmap_randomness
+                )
+        else:
+            self._vmap_qnetworkN0 = _LoopVmapModule(
+                self.qvalue_network, (None, 0), functional=True
+            )
+            if self._version == 1:
+                self._vmap_qnetwork00 = _LoopVmapModule(
+                    self.qvalue_network, functional=True
+                )

@property
def target_entropy_buffer(self):
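The two branches above are meant to be numerically interchangeable: vmapping over stacked parameters and looping over the ensemble member-by-member compute the same quantity. A standalone sketch of that equivalence with plain `torch.func` (names here are illustrative, independent of the torchrl helpers):

```python
import copy
import torch
from torch import nn, vmap
from torch.func import functional_call, stack_module_state

# Four Q-network stand-ins evaluated on one shared input, once with vmap over
# stacked parameters and once with an explicit Python loop.
nets = [nn.Linear(3, 1) for _ in range(4)]
params, buffers = stack_module_state(nets)
base = copy.deepcopy(nets[0]).to("meta")  # stateless template for functional_call
x = torch.randn(8, 3)

def call_one(p, b, x):
    return functional_call(base, (p, b), (x,))

out_vmap = vmap(call_one, in_dims=(0, 0, None))(params, buffers, x)
out_loop = torch.stack([net(x) for net in nets], dim=0)
assert torch.allclose(out_vmap, out_loop, atol=1e-6)
```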
@@ -579,7 +594,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
else:
loss_qvalue, value_metadata = self._qvalue_v2_loss(tensordict_reshape)
loss_value = None
-        loss_actor, metadata_actor = self._actor_loss(tensordict_reshape)
+        loss_actor, metadata_actor = self.actor_loss(tensordict_reshape)
loss_alpha = self._alpha_loss(log_prob=metadata_actor["log_prob"])
tensordict_reshape.set(self.tensor_keys.priority, value_metadata["td_error"])
if (loss_actor.shape != loss_qvalue.shape) or (
@@ -614,9 +629,18 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
def _cached_detached_qvalue_params(self):
return self.qvalue_network_params.detach()

-    def _actor_loss(
+    def actor_loss(
self, tensordict: TensorDictBase
) -> Tuple[Tensor, Dict[str, Tensor]]:
"""The loss for the actor.

Args:
tensordict (TensorDictBase): the input data. See :attr:`~.in_keys` for more details
on the required fields.

Returns: a tensor containing the actor loss along with a dictionary of metadata.

"""
with set_exploration_type(
ExplorationType.RANDOM
), self.actor_network_params.to_module(self.actor_network):
@@ -626,10 +650,12 @@ def _actor_loss(

td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
td_q.set(self.tensor_keys.action, a_reparm)

td_q = self._vmap_qnetworkN0(
td_q,
self._cached_detached_qvalue_params, # should we clone?
)

min_q_logprob = (
td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1)
)
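An aside on the `min(0)[0]` reduction above: `td_q` stacks one Q-estimate per network along dim 0, and the actor is trained against the pessimistic (minimum) estimate, as is standard in SAC. In shapes (illustrative sizes):

```python
import torch

state_action_value = torch.randn(4, 256, 1)  # (num_qvalue_nets, batch, 1)
min_q = state_action_value.min(0)[0].squeeze(-1)  # elementwise min over the ensemble -> (256,)
print(min_q.shape)  # torch.Size([256])
```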
69 changes: 67 additions & 2 deletions torchrl/objectives/utils.py
@@ -8,14 +8,16 @@
import re
import warnings
from enum import Enum
-from typing import Iterable, Optional, Union
+from typing import Iterable, Optional, Tuple, Union

import torch
from tensordict import TensorDict, TensorDictBase
from tensordict.nn import TensorDictModule
from tensordict.utils import _zip_strict
from torch import nn, Tensor
from torch.nn import functional as F
from torch.nn.modules import dropout
from torch.utils._pytree import tree_map

try:
from torch import vmap
@@ -480,7 +482,7 @@ def new_fun(self, netname=None):
return new_fun


-def _vmap_func(module, *args, func=None, **kwargs):
+def _vmap_func(module, *args, func=None, call_vmap: bool = True, **kwargs):
try:

def decorated_module(*module_args_params):
@@ -503,6 +505,69 @@ def decorated_module(*module_args_params):
) from err


class _LoopVmapModule(nn.Module):
def __init__(
self,
module: nn.Module,
in_dims: Tuple[int | None] = None,
out_dims: Tuple[int | None] = None,
register_module: bool = False,
functional: bool = False,
):
super().__init__()
self.register_module = register_module
if not register_module:
self.__dict__["module"] = module
else:
self.module = module
self.in_dims = in_dims
if out_dims is not None:
raise NotImplementedError("out_dims not implemented yet.")
self.out_dims = out_dims
self.functional = functional

def forward(self, *args):
n = None
to_rep = []
if self.in_dims is None:
self.in_dims = [0] * len(args)
args = list(args)
for i, (arg, in_dim) in enumerate(_zip_strict(args, self.in_dims)):
if in_dim is not None:
arg = arg.unbind(in_dim)
if n is None:
n = len(arg)
elif n != len(arg):
raise ValueError(
f"The length of the unbound args differs: {n} vs {len(arg)}."
)
args[i] = arg
else:
to_rep.append(i)
args = [
tuple(arg.copy() for _ in range(n)) if i in to_rep else arg
for i, arg in enumerate(args)
]
out = []
n_out = None
for _args in zip(*args):
if self.functional:
with _args[-1].to_module(self.module):
out.append(self.module(*_args[:-1]))
else:
out.append(self.module(*_args))
if n_out is None:
n_out = len(out[-1]) if isinstance(out[-1], tuple) else 1
if n_out > 1:
return tree_map(lambda *x: torch.stack(x, dim=0), *out)  # stack matching leaves across outputs
elif n_out == 1:
# We explicitly assume that out can be stacked
result = torch.stack(out, dim=0)
return result
else:
raise ValueError("Could not determine the number of outputs.")


def _reduce(tensor: torch.Tensor, reduction: str) -> Union[float, torch.Tensor]:
"""Reduces a tensor given the reduction method."""
if reduction == "none":
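Finally, a hypothetical end-to-end check of `_LoopVmapModule` against the vmap path it replaces; the hand-rolled parameter stacking below mirrors what `LossModule` does internally, and the equality check assumes a deterministic module:

```python
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torch import nn

from torchrl.objectives.utils import _LoopVmapModule, _vmap_func

module = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=["value"])
# Stack N=4 parameter sets along a new leading dim, as the loss modules do.
params = torch.stack([TensorDict.from_module(module).clone() for _ in range(4)])
td = TensorDict({"obs": torch.randn(8, 3)}, batch_size=[8])

loop_mod = _LoopVmapModule(module, (None, 0), functional=True)
out_loop = loop_mod(td.clone(), params)

vmap_fn = _vmap_func(module, (None, 0))
out_vmap = vmap_fn(td.clone(), params)

torch.testing.assert_close(out_loop["value"], out_vmap["value"])
print(out_loop["value"].shape)  # torch.Size([4, 8, 1]): one slice per parameter set
```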