Skip to content

Commit

Permalink
fix(optim): move apply_updates in grad context
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed May 22, 2024
1 parent 6c5304a commit 5f9b918
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 11 deletions.
15 changes: 7 additions & 8 deletions torchopt/optim/func/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,13 @@ def step(
with torch.enable_grad():
# Step parameters only
grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)

updates, self.optim_state = self.impl.update(
grads,
self.optim_state,
params=params,
inplace=inplace,
)
return apply_updates(params, updates, inplace=inplace)
updates, self.optim_state = self.impl.update(
grads,
self.optim_state,
params=params,
inplace=inplace,
)
return apply_updates(params, updates, inplace=inplace)

def state_dict(self) -> OptState:
"""Extract the references of the optimizer states.
Expand Down
10 changes: 7 additions & 3 deletions torchopt/optim/meta/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,14 @@ def step(self, loss: torch.Tensor) -> None: # pylint: disable=too-many-locals
with torch.enable_grad():
# Step parameters only
grads = torch.autograd.grad(loss, flat_params, create_graph=True, allow_unused=True)
updates, new_state = self.impl.update(
grads,
state,
params=flat_params,
inplace=False,
)
flat_new_params = apply_updates(flat_params, updates, inplace=False)

updates, new_state = self.impl.update(grads, state, params=flat_params, inplace=False)

flat_new_params = apply_updates(flat_params, updates, inplace=False)
new_params: ModuleTensorContainers = pytree.tree_unflatten( # type: ignore[assignment]
container_treespec,
flat_new_params,
Expand Down

0 comments on commit 5f9b918

Please sign in to comment.