diff --git a/bmtrain/optim/adam.py b/bmtrain/optim/adam.py
index fdec501b..12bbf9e6 100644
--- a/bmtrain/optim/adam.py
+++ b/bmtrain/optim/adam.py
@@ -3,6 +3,7 @@
 import torch.optim._functional as F
 from . import _cuda as C
 from .. import nccl
+import inspect
 
 class AdamOptimizer(torch.optim.Optimizer):
     """
@@ -117,6 +118,9 @@ def step(self, closure=None):
                             state['step']
                         )
                     else:
+                        other_kwargs = {}
+                        if 'maximize' in inspect.signature(F.adam).parameters:
+                            other_kwargs['maximize'] = False
                         F.adam(
                             [p],
                             [p.grad / self._scale],
@@ -130,7 +134,7 @@ def step(self, closure=None):
                             lr=0.0 if state["step"] <= self._hold_steps else group['lr'],
                             weight_decay=group['weight_decay'],
                             eps=group['eps'],
-                            maximize=False
+                            **other_kwargs
                         )
 
         return loss
diff --git a/bmtrain/optim/adam_offload.py b/bmtrain/optim/adam_offload.py
index 207f0abe..45fa726f 100644
--- a/bmtrain/optim/adam_offload.py
+++ b/bmtrain/optim/adam_offload.py
@@ -4,6 +4,7 @@
 from . import _cuda as G
 from .. import nccl
 import torch.optim._functional as F
+import inspect
 
 class AdamOffloadOptimizer(torch.optim.Optimizer):
     """
@@ -27,6 +28,7 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0
         self._scale = scale
         self._steps_since_last_scale = 0
         self._hold_steps = hold_steps
+        self._events = {}
 
     @property
     def scale(self):
@@ -102,21 +104,22 @@ def step(self, closure=None):
                             # placeholder
                             state["_grad_fp32"] = torch.empty(p.size(), dtype=torch.float32, pin_memory=True)   # on host
-                        state["_load_event"] = torch.cuda.Event()
+                    if p not in self._events:
+                        self._events[p] = torch.cuda.Event()
 
-                    update_params.append((p, state, group['betas'][0], group['betas'][1], group['eps'], group['lr'], group['weight_decay']))
+                    update_params.append((p, state, self._events[p], group['betas'][0], group['betas'][1], group['eps'], group['lr'], group['weight_decay']))
 
         # transfer parameters to host asynchronously
-        for param, state, _, _, _, _, _ in update_params:
+        for param, state, event, _, _, _, _, _ in update_params:
             if param.dtype == torch.half:
                 state["_grad_fp16"].copy_(param.grad, non_blocking=True)
             else:
                 state["_grad_fp32"].copy_(param.grad, non_blocking=True)
-            torch.cuda.current_stream().record_event(state["_load_event"])
+            torch.cuda.current_stream().record_event(event)
 
-        for param, state, beta1, beta2, eps, lr, weight_decay in update_params:
+        for param, state, event, beta1, beta2, eps, lr, weight_decay in update_params:
             # wait for transfer to host
-            state["_load_event"].synchronize()
+            event.synchronize()
 
             state["step"] += 1
 
@@ -138,6 +141,10 @@ def step(self, closure=None):
                 param.copy_(state["_param_fp16"], non_blocking=True)
             else:
                 state["_grad_fp32"].mul_(1.0 / self._scale)
+
+                other_kwargs = {}
+                if 'maximize' in inspect.signature(F.adam).parameters:
+                    other_kwargs['maximize'] = False
                 F.adam(
                     [state["_param_fp32"]],
                     [state["_grad_fp32"]],
@@ -151,7 +158,7 @@ def step(self, closure=None):
                     lr=0.0 if state["step"] <= self._hold_steps else lr,
                     weight_decay=weight_decay,
                     eps=eps,
-                    maximize=False
+                    **other_kwargs
                 )
                 # transfer parameters back to device asynchronously
                 param.copy_(state["_param_fp32"], non_blocking=True)
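
For context on the main idiom in this patch: `torch.optim._functional.adam` only accepts the `maximize` keyword in newer PyTorch releases, so the patch probes the function's signature with `inspect` and forwards the flag only when it exists. Below is a minimal, self-contained sketch of that feature-detection pattern; the `call_with_supported_kwargs` helper and the `adam_old` / `adam_new` stand-ins are illustrative only (not part of BMTrain or PyTorch), since the real `F.adam` signature varies across versions.

```python
import inspect


def call_with_supported_kwargs(fn, *args, optional_kwargs=None, **kwargs):
    """Forward only those optional keyword arguments that `fn` declares.

    Mirrors the patch: `maximize=False` is passed to F.adam only when
    `inspect.signature(F.adam).parameters` contains a `maximize` entry.
    """
    optional_kwargs = optional_kwargs or {}
    accepted = inspect.signature(fn).parameters
    extra = {k: v for k, v in optional_kwargs.items() if k in accepted}
    return fn(*args, **kwargs, **extra)


# Stand-ins for an old and a new F.adam-style API (hypothetical).
def adam_old(params, *, lr, eps):
    return f"old: lr={lr}, eps={eps}"


def adam_new(params, *, lr, eps, maximize):
    return f"new: lr={lr}, eps={eps}, maximize={maximize}"


print(call_with_supported_kwargs(adam_old, [], lr=1e-3, eps=1e-8,
                                 optional_kwargs={"maximize": False}))
print(call_with_supported_kwargs(adam_new, [], lr=1e-3, eps=1e-8,
                                 optional_kwargs={"maximize": False}))
```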
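
The adam_offload.py side of the patch also stops storing a `torch.cuda.Event` inside `self.state[p]` and instead caches one event per parameter on the optimizer itself (`self._events`), presumably so the event is created lazily on every run and never travels through optimizer state. A rough sketch of that record-then-synchronize pattern is below, under the assumption that a CUDA device is available; the `OffloadEventCache` class and its method names are hypothetical, introduced only for illustration.

```python
import torch


class OffloadEventCache:
    """Illustrative sketch: keep one reusable torch.cuda.Event per parameter,
    record it after the async device->host gradient copy, and synchronize
    before the host-side update reads the pinned buffer."""

    def __init__(self):
        self._events = {}  # param -> cached CUDA event, reused every step

    def stage(self, param, host_grad):
        # Create the event lazily, once per parameter (mirrors the patch's
        # `if p not in self._events:` check) instead of storing it in state.
        if param not in self._events:
            self._events[param] = torch.cuda.Event()
        event = self._events[param]
        # Async D2H copy into pinned memory, then mark completion on the stream.
        host_grad.copy_(param.grad, non_blocking=True)
        torch.cuda.current_stream().record_event(event)
        return event

    def wait(self, event):
        # Block the CPU until the copy recorded above has finished.
        event.synchronize()


if torch.cuda.is_available():
    p = torch.nn.Parameter(torch.randn(8, device="cuda"))
    p.grad = torch.randn(8, device="cuda")
    pinned_grad = torch.empty(8, dtype=torch.float32, pin_memory=True)

    cache = OffloadEventCache()
    cache.wait(cache.stage(p, pinned_grad))
    # pinned_grad now safely holds the gradient for a CPU-side Adam step.
```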