chore(optim): wrap torch.autograd.grad() with torch.enable_grad() context
XuehaiPan committed May 21, 2024
1 parent b3f570c commit fa77c1a
Showing 2 changed files with 13 additions and 15 deletions.
6 changes: 4 additions & 2 deletions torchopt/optim/func/base.py
@@ -87,8 +87,10 @@ def step(
         if inplace is None:
             inplace = self.inplace

-        # Step parameter only
-        grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)
+        with torch.enable_grad():
+            # Step parameters only
+            grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)
+
         updates, self.optim_state = self.impl.update(
             grads,
             self.optim_state,
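For context, a minimal standalone sketch of the pattern this hunk adopts (plain PyTorch only; meta_param, param, and the hand-written SGD-style update are illustrative, not TorchOpt code): gradients are taken with create_graph=True inside a torch.enable_grad() block, so the gradient computation itself stays in the autograd graph for a later outer backward pass.

import torch

meta_param = torch.tensor(1.5, requires_grad=True)  # outer / meta parameter
param = meta_param * 2.0                             # inner parameter derived from it
loss = (param - 1.0) ** 2                            # inner loss

with torch.enable_grad():
    # create_graph=True asks autograd to keep the backward pass in the graph;
    # the enable_grad() context makes that intent explicit even when a
    # surrounding context (e.g. torch.no_grad()) has turned gradient tracking off.
    (grad,) = torch.autograd.grad(loss, param, create_graph=True, allow_unused=True)

new_param = param - 0.1 * grad        # differentiable SGD-style update
outer_loss = (new_param - 0.5) ** 2   # outer objective on the updated parameter
outer_loss.backward()                 # flows back through the update to meta_param
print(meta_param.grad)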
22 changes: 9 additions & 13 deletions torchopt/optim/meta/base.py
@@ -72,26 +72,22 @@ def step(self, loss: torch.Tensor) -> None:  # pylint: disable=too-many-locals
         ):
             flat_params: TupleOfTensors
             flat_params, container_treespec = pytree.tree_flatten_as_tuple(param_container)  # type: ignore[arg-type]
+
             if isinstance(state, UninitializedState):
                 state = self.impl.init(flat_params)
-            grads = torch.autograd.grad(
-                loss,
-                flat_params,
-                create_graph=True,
-                allow_unused=True,
-            )
-            updates, new_state = self.impl.update(
-                grads,
-                state,
-                params=flat_params,
-                inplace=False,
-            )
-            self.state_groups[i] = new_state
+
+            with torch.enable_grad():
+                grads = torch.autograd.grad(loss, flat_params, create_graph=True, allow_unused=True)
+
+            updates, new_state = self.impl.update(grads, state, params=flat_params, inplace=False)
+
             flat_new_params = apply_updates(flat_params, updates, inplace=False)
             new_params: ModuleTensorContainers = pytree.tree_unflatten(  # type: ignore[assignment]
                 container_treespec,
                 flat_new_params,
             )

+            self.state_groups[i] = new_state
             for container, new_param in zip(param_container, new_params):
                 container.update(new_param)

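The refactored loop above is what runs when a TorchOpt meta-optimizer takes a differentiable inner step. A hedged usage-level sketch of where MetaOptimizer.step() sits in an inner/outer loop (net, x, and y are placeholders, and torchopt.MetaSGD is assumed to follow its documented MetaOptimizer interface; this snippet is not part of the commit):

import torch
import torch.nn as nn
import torchopt

net = nn.Linear(4, 1)                      # inner model whose parameters are meta-learned
inner_opt = torchopt.MetaSGD(net, lr=0.1)  # differentiable optimizer wrapping the module

x, y = torch.randn(8, 4), torch.randn(8, 1)

inner_loss = ((net(x) - y) ** 2).mean()
inner_opt.step(inner_loss)                 # runs the step() shown above; the update stays in the graph

outer_loss = ((net(x) - y) ** 2).mean()
outer_loss.backward()                      # backpropagates through the inner update to the original parameters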
