Skip to content

Commit

Permalink
fix OOM caused by PipeDreamBlockList.init_param_storage
Browse files Browse the repository at this point in the history
  • Loading branch information
MayDomine committed May 13, 2024
1 parent 0e22e96 commit 88601be
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions bmtrain/block_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,8 +664,10 @@ def __init__(self, modules: Iterable[Block], num_hidden=1, use_checkpoint=False)
m.init_param_storage()
partition_modules.append(m)
else:
m.init_param_storage()
del m
#m.init_param_storage()
for name, param in m.named_parameters():
c = OpAllGather.apply(param)
del param
super().__init__(partition_modules, num_hidden, mode=mode)
self.fisrt_module = (self._modules['0'],)
self.last_module = (self._modules[str(len(self._modules) - 1)],)
Expand Down Expand Up @@ -712,9 +714,13 @@ def _add_head(self, module):

def add_head(self, module, use_checkpoint=False):
module = _block_wrapper(module, self.module_dict, mode="1F1B", zero_level=2, use_checkpoint=use_checkpoint)
module.init_param_storage()
if config['topology'].pipe_rank != 0:
for name, param in module.named_parameters():
c = OpAllGather.apply(param)
del param
return DummyForward
else:
module.init_param_storage()
self._add_head(module)
return module

Expand Down Expand Up @@ -766,9 +772,12 @@ def _add_tail(self, module):

def add_tail(self, module, use_checkpoint=False):
module = _block_wrapper(module, self.module_dict, mode="1F1B", zero_level=2, use_checkpoint=use_checkpoint)
module.init_param_storage()
if config['topology'].pipe_rank != config['topology'].pipe_size - 1:
for name, param in module.named_parameters():
c = OpAllGather.apply(param)
del param
return DummyForward
else:
module.init_param_storage()
self._add_tail(module)
return module

0 comments on commit 88601be

Please sign in to comment.