FX: remove cuda events from optimizer state_dict. FX: F.adam maximize argument
a710128 committed May 12, 2022
1 parent cb06d14 commit f6adb57
Showing 2 changed files with 19 additions and 8 deletions.
bmtrain/optim/adam.py (6 changes: 5 additions & 1 deletion)
@@ -3,6 +3,7 @@
 import torch.optim._functional as F
 from . import _cuda as C
 from .. import nccl
+import inspect
 
 class AdamOptimizer(torch.optim.Optimizer):
     """
@@ -117,6 +118,9 @@ def step(self, closure=None):
                         state['step']
                     )
                 else:
+                    other_kwargs = {}
+                    if 'maximize' in inspect.signature(F.adam).parameters:
+                        other_kwargs['maximize'] = False
                     F.adam(
                         [p],
                         [p.grad / self._scale],
@@ -130,7 +134,7 @@ def step(self, closure=None):
                         lr=0.0 if state["step"] <= self._hold_steps else group['lr'],
                         weight_decay=group['weight_decay'],
                         eps=group['eps'],
-                        maximize=False
+                        **other_kwargs
                     )
 
         return loss
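The guarded keyword pattern above is the part worth reading in isolation. The sketch below is illustrative only and is not taken from the commit: it assumes that torch.optim._functional.adam exists on the installed PyTorch, that newer releases require a maximize keyword which older ones reject, and the helper name optional_adam_kwargs is made up for the example.

# Illustrative compatibility shim, assuming only the stdlib and PyTorch's
# private functional API torch.optim._functional.adam.
import inspect
import torch.optim._functional as F

def optional_adam_kwargs():
    # Pass maximize=False only when this PyTorch's F.adam accepts it;
    # older versions raise TypeError on the unknown keyword.
    extra = {}
    if 'maximize' in inspect.signature(F.adam).parameters:
        extra['maximize'] = False
    return extra

# usage: F.adam(params, grads, ..., **optional_adam_kwargs())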
bmtrain/optim/adam_offload.py (21 changes: 14 additions & 7 deletions)
@@ -4,6 +4,7 @@
 from . import _cuda as G
 from .. import nccl
 import torch.optim._functional as F
+import inspect
 
 class AdamOffloadOptimizer(torch.optim.Optimizer):
     """
@@ -27,6 +28,7 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0
         self._scale = scale
         self._steps_since_last_scale = 0
         self._hold_steps = hold_steps
+        self._events = {}
 
     @property
     def scale(self):
@@ -102,21 +104,22 @@ def step(self, closure=None):
                     # placeholder
                     state["_grad_fp32"] = torch.empty(p.size(), dtype=torch.float32, pin_memory=True)   # on host
 
-                    state["_load_event"] = torch.cuda.Event()
+                if p not in self._events:
+                    self._events[p] = torch.cuda.Event()
 
-                update_params.append((p, state, group['betas'][0], group['betas'][1], group['eps'], group['lr'], group['weight_decay']))
+                update_params.append((p, state, self._events[p], group['betas'][0], group['betas'][1], group['eps'], group['lr'], group['weight_decay']))
 
         # transfer parameters to host asynchronously
-        for param, state, _, _, _, _, _ in update_params:
+        for param, state, event, _, _, _, _, _ in update_params:
             if param.dtype == torch.half:
                 state["_grad_fp16"].copy_(param.grad, non_blocking=True)
             else:
                 state["_grad_fp32"].copy_(param.grad, non_blocking=True)
-            torch.cuda.current_stream().record_event(state["_load_event"])
+            torch.cuda.current_stream().record_event(event)
 
-        for param, state, beta1, beta2, eps, lr, weight_decay in update_params:
+        for param, state, event, beta1, beta2, eps, lr, weight_decay in update_params:
             # wait for transfer to host
-            state["_load_event"].synchronize()
+            event.synchronize()
 
             state["step"] += 1
 
@@ -138,6 +141,10 @@
                 param.copy_(state["_param_fp16"], non_blocking=True)
             else:
                 state["_grad_fp32"].mul_(1.0 / self._scale)
+
+                other_kwargs = {}
+                if 'maximize' in inspect.signature(F.adam).parameters:
+                    other_kwargs['maximize'] = False
                 F.adam(
                     [state["_param_fp32"]],
                     [state["_grad_fp32"]],
@@ -151,7 +158,7 @@
                     lr=0.0 if state["step"] <= self._hold_steps else lr,
                     weight_decay=weight_decay,
                     eps=eps,
-                    maximize=False
+                    **other_kwargs
                 )
                 # transfer parameters back to device asynchronously
                 param.copy_(state["_param_fp32"], non_blocking=True)
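The adam_offload.py half of the commit moves the per-parameter torch.cuda.Event out of self.state, which state_dict() serializes, and into a plain dict on the optimizer keyed by parameter. Below is a minimal sketch of that record-then-synchronize pattern around the pinned-memory gradient copies; it is not the commit's code, it needs a CUDA device to run, and the class and method names are invented for illustration.

import torch

class HostOffloadBuffers:
    """Toy holder showing why events live outside optimizer state."""

    def __init__(self):
        # Keyed by parameter; never placed in optimizer state, so state_dict()
        # contains only tensors and ints and stays serializable.
        self._events = {}

    def begin_copy_to_host(self, param, host_grad):
        # Lazily create one reusable event per parameter.
        if param not in self._events:
            self._events[param] = torch.cuda.Event()
        # Asynchronous device-to-host copy into pinned memory, then mark
        # its completion on the current CUDA stream.
        host_grad.copy_(param.grad, non_blocking=True)
        torch.cuda.current_stream().record_event(self._events[param])

    def wait_for(self, param):
        # Block the CPU until this parameter's gradient has landed on the host.
        self._events[param].synchronize()

With the events kept in a side dict like this, saving and re-loading the optimizer never tries to pickle CUDA events, and a freshly loaded optimizer simply creates new events on its next step.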
