Merge pull request #195 from OpenBMB/dev
add grad scale for optim_manager && fix workflow action
MayDomine committed Apr 26, 2024
2 parents 6670f5c + 151f679 commit b903a31
Showing 3 changed files with 15 additions and 6 deletions.
7 changes: 4 additions & 3 deletions .github/workflows/build.yml
@@ -6,14 +6,15 @@ on:
     branches:
       - 'dev'
       - 'main'
+  push:
+    branches:
+      - 'dev'
 
 jobs:
   build-archive-wheel:
 
     uses: OpenBMB/BMTrain/.github/workflows/build_whl.yml@main
-    secrets:
-      DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }}
-      DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
+    secrets: inherit
 
   publish:
     needs: build-archive-wheel
6 changes: 5 additions & 1 deletion bmtrain/optim/optim_manager.py
@@ -52,6 +52,7 @@ def __init__(self,
         loss_scale_steps : int = 1024,
         min_loss_scale = 1,
         max_loss_scale = float("inf"),
+        grad_scale : Optional[int] = None,
     ):
         if loss_scale is not None:
             self.loss_scale = loss_scale
@@ -64,6 +65,9 @@
         self.loss_scale_steps = loss_scale_steps
         self.min_loss_scale = min_loss_scale
         self.max_loss_scale = max_loss_scale
+        if grad_scale is None:
+            grad_scale = config['zero_size']
+        self.grad_scale = grad_scale
 
         self.optimizers = []
         self.lr_schedulers = []
@@ -85,7 +89,7 @@ def add_optimizer(
 
     def scale_loss(self, loss : torch.Tensor) -> torch.Tensor:
 
-        return loss * (self.loss_scale / (config['world_size']//(config['tp_size']*config['pipe_size']))) # loss scale
+        return loss * ( self.loss_scale / self.grad_scale ) # loss scale
 
     def backward(self, loss : torch.Tensor):
         """
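For context, here is a minimal sketch (not part of this commit; values and the single-process setup are illustrative) of how the new grad_scale argument could be passed. When it is omitted, the manager falls back to config['zero_size'], which in the common case matches the previously hard-coded divisor world_size // (tp_size * pipe_size), so default behaviour should be unchanged while custom setups can now override it.

# Illustrative sketch (not from this commit): passing grad_scale explicitly
# when building an OptimManager. Assumes a BMTrain program launched with
# torchrun so that bmt.init_distributed() can populate the global config.
import torch
import bmtrain as bmt

bmt.init_distributed()  # fills in config['zero_size'], config['comm'], ...

# Without grad_scale, the manager now defaults to config['zero_size'];
# passing it overrides the divisor that scale_loss applies.
manager = bmt.optim.OptimManager(loss_scale=1024, grad_scale=4)

loss = torch.tensor(2.0, device="cuda")
scaled = manager.scale_loss(loss)  # 2.0 * (1024 / 4) == 512.0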
8 changes: 6 additions & 2 deletions bmtrain/synchronize.py
@@ -2,6 +2,7 @@
 from . import distributed, nccl
 from .global_var import config
 import warnings
+from typing import Optional
 
 def synchronize():
     """
@@ -24,14 +25,17 @@ def wait_loader():
     config['calc_stream'].record_event(config['load_event'])
 
 
-def sum_loss(loss : torch.Tensor):
+def sum_loss(loss : torch.Tensor, comm: Optional[nccl.NCCLCommunicator] = None):
     """
     Sum the loss across all workers.
     This is a helper function to reduce the loss across all workers.
     """
+    if comm is None:
+        comm = config['comm']
     warnings.warn("bmtrain.sum_loss is deprecated and will be removed in later version. Use bmtrain.distributed.all_reduce instead.", DeprecationWarning)
-    return distributed.all_reduce(loss, "sum") / config['world_size']
+
+    return distributed.all_reduce(loss, "avg", comm)
 
 def gather_result(result: torch.Tensor):
     warnings.warn("bmtrain.gather_result is deprecated and will be removed in later version. Use bmtrain.distributed.all_gather instead.", DeprecationWarning)
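A minimal usage sketch (assumed, not taken from the repository) of the reworked sum_loss: it now averages over an optional communicator, defaulting to config['comm'], while the deprecation warning still steers callers toward bmtrain.distributed.all_reduce. Requires a distributed BMTrain launch (e.g. via torchrun); the tensor values are placeholders.

# Illustrative sketch (assumed usage): the updated sum_loss signature and the
# replacement its deprecation warning recommends.
import torch
import bmtrain as bmt
from bmtrain.global_var import config

bmt.init_distributed()
loss = torch.tensor(1.0, device="cuda")

# Deprecated helper: averages the loss over the given communicator,
# falling back to the global config['comm'] when none is passed.
avg_default = bmt.sum_loss(loss)
avg_explicit = bmt.sum_loss(loss, comm=config['comm'])

# Preferred, per the warning text: call all_reduce directly.
avg_direct = bmt.distributed.all_reduce(loss, "avg", config['comm'])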
