diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 98200465..866d00d5 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -1,6 +1,6 @@ from typing import Dict, Iterable, Iterator, Union, List -from .utils import (round_up, tp_split_tensor) +from .utils import round_up, tp_split_tensor from .global_var import config import torch from . import nccl @@ -10,7 +10,9 @@ import inspect from torch.utils.checkpoint import checkpoint + def storage_type_cuda(storage_type): + """Convert storage_type to cuda storage_type.""" STORAGE_MAP = { torch.FloatStorage: torch.cuda.FloatStorage, torch.DoubleStorage: torch.cuda.DoubleStorage, @@ -33,7 +35,9 @@ def storage_type_cuda(storage_type): raise ValueError("Unknown storage type: {}".format(storage_type)) return STORAGE_MAP[storage_type] -def _get_param_kw(param : DistributedParameter): + +def _get_param_kw(param: DistributedParameter): + """Get DistributedParameter kw name.""" type_name = str(param.dtype).split(".")[-1] grad_name = "_grad" if param.requires_grad else "_nograd" group_name = "" @@ -41,19 +45,20 @@ def _get_param_kw(param : DistributedParameter): group_name = "_g_" + param.group return type_name + grad_name + group_name + class Block(torch.nn.Module): - """ A block containing two memory-saving methods of ZeRO and checkpoint. - For details please refer to `ZeRO `_ and - `Checkpointing `_ . + """A block containing two memory-saving methods of ZeRO and checkpoint. + For details please refer to `ZeRO `_ and + `Checkpointing `_ . Args: inner_module (torch.nn.Module): The module to reduce memory usage. All kinds of modules are supported. use_checkpoint (boolean): use checkpoint or not. Default True. - zero_level (int): 2 (ZeRO-2) indicates that optimizer states and gradients are partitioned across the process, + zero_level (int): 2 (ZeRO-2) indicates that optimizer states and gradients are partitioned across the process, 3 (ZeRO-3) means that the parameters are partitioned on the basis of ZeRO-2. Default 3. initialized (bool): initialized parameter storage. Default False. mode (str): the mode should be "PIPE" when running in pipeline mode, otherwise mode="BLOCK". Default "BLOCK" - + Examples: >>> transformer_block = TransformerBlock(...) >>> block = Block(transformer_block) >>> y1, ... = block(x) @@ -61,7 +66,15 @@ class Block(torch.nn.Module): >>> y2, ...
= transformer_block(x) >>> assert torch.allclose(y1, y2) """ - def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_level=3, initialized=False, mode="BLOCK"): + + def __init__( + self, + inner_module: torch.nn.Module, + use_checkpoint=True, + zero_level=3, + initialized=False, + mode="BLOCK", + ): super().__init__() self._module = inner_module self._inputs = None @@ -70,17 +83,17 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_lev self._backward_block_ctx = None self._param_info = [] - self._storage_params : Dict[str, torch.nn.Parameter] = {} + self._storage_params: Dict[str, torch.nn.Parameter] = {} self._storage_info = {} self._ready = False self._use_checkpoint = use_checkpoint self._is_first_layer = True self._is_last_layer = True - self._need_release = True - self._next_module = None #save the next module of self - self._pre_module = None #save the pre module of self - self._mode = mode #BLOCK or PIPE + self._need_release = True + self._next_module = None # save the next module of self + self._pre_module = None # save the pre module of self + self._mode = mode # BLOCK or PIPE self.all_input_no_grad = False self.all_param_no_grad = False self._zero_level = zero_level @@ -88,6 +101,7 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_lev self.init_param_storage() def reference(self, block): + """Make this block be a reference of the input Block.""" self._param_info = block._param_info self._storage_params = block._storage_params self._storage_info = block._storage_info @@ -96,13 +110,16 @@ def reference(self, block): self._need_release = False def init_param_storage(self): + """Init param storage.""" # sort parameters by name ordered_parameters = list(self._module.named_parameters()) # calc total number of parameters for name, param in ordered_parameters: if not isinstance(param, DistributedParameter): - raise ValueError("All parameters in checkpoint block must be DistributedParameter.") + raise ValueError( + "All parameters in checkpoint block must be DistributedParameter." 
+ ) storage_type = storage_type_cuda(param.storage_type()) kw_name = _get_param_kw(param) @@ -122,14 +139,14 @@ def init_param_storage(self): "storage_type": storage_type, "requires_grad": param.requires_grad, "group": param.group, - "zero_comm" : zero_comm + "zero_comm": zero_comm, } param_shape = param._original_shape self._storage_info[kw_name]["total"] = round_up( - self._storage_info[kw_name]["total"] + param_shape.numel(), - 512 // param.element_size() + self._storage_info[kw_name]["total"] + param_shape.numel(), + 512 // param.element_size(), # 512 bytes aligned ) @@ -139,14 +156,15 @@ def init_param_storage(self): comm = val["zero_comm"] world_size = nccl.commCount(comm) rank = nccl.commRank(comm) - val["world_size"] = world_size - partition_size = round_up(val["total"], val["world_size"]) // val["world_size"] + val["world_size"] = world_size + partition_size = ( + round_up(val["total"], val["world_size"]) // val["world_size"] + ) val["partition_size"] = partition_size val["begin"] = rank * partition_size - val["end"] = (rank+1) * partition_size + val["end"] = (rank + 1) * partition_size offsets[kw] = 0 - storage_type = val["storage_type"] storage_param_buffer = storage_type(partition_size) @@ -163,7 +181,6 @@ def init_param_storage(self): else: storage_param.requires_grad_(False) - self._storage_params[kw] = storage_param # initialize parameters in module @@ -176,19 +193,21 @@ def init_param_storage(self): param_end = offsets[kw_name] offsets[kw_name] = round_up(offsets[kw_name], 512 // param.element_size()) - self._param_info.append({ - "parameter": param, - "name": name, - "offset": param_st, - "size": param_shape.numel(), - "shape": param_shape, - "kw_name": kw_name, - }) + self._param_info.append( + { + "parameter": param, + "name": name, + "offset": param_st, + "size": param_shape.numel(), + "shape": param_shape, + "kw_name": kw_name, + } + ) # copy values to buffer for normal parameter storage_st = self._storage_info[kw_name]["begin"] storage_end = self._storage_info[kw_name]["end"] - + # make parameter contiguous in storage with torch.no_grad(): contiguous_param = OpAllGather.apply(param) @@ -202,18 +221,25 @@ def init_param_storage(self): # copy to offset in buffer storage to_offset_st = offset_st + param_st - storage_st to_offset_end = offset_end + param_st - storage_st - + # copy to buffer # PyTorch 1.11 changed the API of storage.__getitem__ d_dtype = self._storage_params[kw_name].dtype d_device = self._storage_params[kw_name].device - param.data = torch.tensor([], dtype=param.dtype, device=param.device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,)) + param.data = torch.tensor( + [], dtype=param.dtype, device=param.device + ).set_( + self._storage_params[kw_name].storage(), + to_offset_st, + (to_offset_end - to_offset_st,), + ) self._param_info[-1]["begin"] = to_offset_st self._param_info[-1]["end"] = (to_offset_end - to_offset_st,) setattr(param, "_start_partition", offset_st) setattr(param, "_end_partition", offset_end) - param.data[:] = \ - torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:] + param.data[:] = torch.tensor([], dtype=d_dtype, device=d_device).set_( + contiguous_param.storage(), offset_st, (offset_end - offset_st,) + )[:] del contiguous_param else: param.data = torch.tensor([], dtype=param.dtype, device=param.device) @@ -226,26 +252,32 @@ def init_param_storage(self): assert offsets[kw] == self._storage_info[kw]["total"] def 
set_pre_module(self, pre_module): + """Set the previous module of the current Block.""" if pre_module is not None: self._pre_module = pre_module pre_module._next_module = self - + def pre_module(self): + """Return the previous module of the current Block.""" return self._pre_module if not self._is_first_layer else None def next_module(self): + """Return the next module of the current Block.""" return self._next_module if not self._is_last_layer else None def release_next_module(self, flag): + """Release the next module of the current Block.""" if self.next_module() is not None: self.next_module().release(flag) def release(self, flag): + """Release the current block's context.""" if self._need_release and self._backward_block_ctx is not None: self._backward_block_ctx.exit(flag, True) - config['load_stream'].record_event(config['load_event']) + config["load_stream"].record_event(config["load_event"]) def pre_hook(self, *args): + """Hook function before forward.""" grad_tensors = [] grad_index = [] arg_list = list(args) @@ -262,7 +294,7 @@ def pre_hook(self, *args): if self._mode != "PIPE" and len(grad_tensors) == 0: self.all_param_no_grad = True for param in self._param_info: - if param['parameter'].requires_grad: + if param["parameter"].requires_grad: self.all_param_no_grad = False break self.all_input_no_grad = True @@ -271,14 +303,15 @@ def pre_hook(self, *args): return arg_list def post_hook(self, out): - tuple_out = (out, ) if isinstance(out, torch.Tensor) else out + """Hook function after forward.""" + tuple_out = (out,) if isinstance(out, torch.Tensor) else out post_out = hook_func.PostHookFunc.apply(self, *tuple_out) if isinstance(out, torch.Tensor) and isinstance(post_out, tuple): return post_out[0] post_out = tuple(post_out) return post_out - def forward(self, *args): + def forward(self, *args): arg_list = self.pre_hook(*args) if self.all_input_no_grad and not self.all_param_no_grad: @@ -286,14 +319,16 @@ def forward(self, *args): return hook_func.OneStepNoGradFunc.apply(self, placeholder, *arg_list) if self._use_checkpoint: - out = checkpoint(self._module, *arg_list, use_reentrant=not self.all_input_no_grad) + out = checkpoint( + self._module, *arg_list, use_reentrant=not self.all_input_no_grad + ) else: out = self._module(*arg_list) return self.post_hook(out) - def __getattr__(self,name:str): - if name=="_module": + def __getattr__(self, name: str): + if name == "_module": return self._module return getattr(self._module, name) @@ -301,7 +336,7 @@ def __setattr__(self, name, value): object.__setattr__(self, name, value) def __getattribute__(self, name: str): - if name=="_parameters": + if name == "_parameters": return self._module._parameters return super().__getattribute__(name) @@ -310,15 +345,25 @@ def __delattr__(self, name): def _save_to_state_dict(self, destination, prefix, keep_vars): raise RuntimeError("._save_to_state_dict() of Block should not be called") - - def state_dict(self, destination=None, prefix='', keep_vars=False): + + def state_dict(self, destination=None, prefix="", keep_vars=False): # gather here with torch.no_grad(): with ZeroContext(self): - return self._module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) - - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs): + return self._module.state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars + ) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): all_keys
= [] for it in self._param_info: key = prefix + it["name"] @@ -326,18 +371,23 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, if key in state_dict: # load here input_param = state_dict[key] - param = it['parameter'] + param = it["parameter"] tp_mode = param._tp_mode if input_param.__class__.__name__ == "DistributedTensorWrapper": input_param = input_param.broadcast() - verify_shape = torch.Size(it["shape"] if not tp_mode else param._tp_original_shape) + verify_shape = torch.Size( + it["shape"] if not tp_mode else param._tp_original_shape + ) if input_param.shape != verify_shape: - error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' - 'the shape in current model is {}.' - .format(key, input_param.shape, verify_shape)) + error_msgs.append( + "size mismatch for {}: copying a param with shape {} from checkpoint, " + "the shape in current model is {}.".format( + key, input_param.shape, verify_shape + ) + ) continue - + param_st = it["offset"] param_end = it["offset"] + it["size"] kw_name = it["kw_name"] @@ -349,17 +399,19 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, continue if param_end <= storage_st: continue - + # copy to buffer verify_size = verify_shape.numel() assert input_param.numel() == verify_size - contiguous_param = input_param.to(it["parameter"].dtype).cuda().contiguous() + contiguous_param = ( + input_param.to(it["parameter"].dtype).cuda().contiguous() + ) tp_split_dim = param._tp_split_dim if tp_mode and tp_split_dim >= 0: contiguous_param = tp_split_tensor(contiguous_param, tp_split_dim) - + offset_st = max(storage_st - param_st, 0) offset_end = min(storage_end - param_st, contiguous_param.numel()) assert offset_st < offset_end @@ -371,8 +423,15 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, # PyTorch 1.11 changed the API of storage.__getitem__ d_dtype = self._storage_params[kw_name].dtype d_device = self._storage_params[kw_name].device - torch.tensor([], dtype=d_dtype, device=d_device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,))[:] = \ - torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:] + torch.tensor([], dtype=d_dtype, device=d_device).set_( + self._storage_params[kw_name].storage(), + to_offset_st, + (to_offset_end - to_offset_st,), + )[:] = torch.tensor([], dtype=d_dtype, device=d_device).set_( + contiguous_param.storage(), offset_st, (offset_end - offset_st,) + )[ + : + ] del contiguous_param elif strict: missing_keys.append(key) @@ -385,28 +444,49 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, input_param = state_dict[key] is_param_lazy = torch.nn.parameter.is_lazy(param) # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ - if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1: + if ( + not is_param_lazy + and len(param.shape) == 0 + and len(input_param.shape) == 1 + ): input_param = input_param[0] - if not is_param_lazy and not isinstance(param, DistributedParameter) and input_param.shape != param.shape: + if ( + not is_param_lazy + and not isinstance(param, DistributedParameter) + and input_param.shape != param.shape + ): # local shape should match the one in checkpoint - error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' - 'the shape in current model is {}.' 
- .format(key, input_param.shape, param.shape)) + error_msgs.append( + "size mismatch for {}: copying a param with shape {} from checkpoint, " + "the shape in current model is {}.".format( + key, input_param.shape, param.shape + ) + ) continue - if not is_param_lazy and isinstance(param, DistributedParameter) and input_param.shape != param._original_shape: - error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' - 'the shape in current model is {}.' - .format(key, input_param.shape, param.shape)) + if ( + not is_param_lazy + and isinstance(param, DistributedParameter) + and input_param.shape != param._original_shape + ): + error_msgs.append( + "size mismatch for {}: copying a param with shape {} from checkpoint, " + "the shape in current model is {}.".format( + key, input_param.shape, param.shape + ) + ) try: with torch.no_grad(): param._copy_data(input_param) except Exception as ex: - error_msgs.append('While copying the parameter named "{}", ' - 'whose dimensions in the model are {} and ' - 'whose dimensions in the checkpoint are {}, ' - 'an exception occurred : {}.' - .format(key, param.size(), input_param.size(), ex.args)) + error_msgs.append( + 'While copying the parameter named "{}", ' + "whose dimensions in the model are {} and " + "whose dimensions in the checkpoint are {}, " + "an exception occurred : {}.".format( + key, param.size(), input_param.size(), ex.args + ) + ) elif strict: missing_keys.append(key) @@ -415,8 +495,11 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, for key in state_dict.keys(): if key.startswith(prefix) and key not in all_keys: unexpected_keys.append(key) - + def grouped_parameters(self): + """ + Yield group params in storage params. + """ ret = {} for kw, val in self._storage_info.items(): if val["group"] not in ret: @@ -431,9 +514,14 @@ def init_parameters(self): """ for it in self._param_info: param = it["parameter"] - if isinstance(param, DistributedParameter) and param._init_method is not None: + if ( + isinstance(param, DistributedParameter) + and param._init_method is not None + ): # initialzie here - tmp_tensor = torch.empty(param._tp_original_shape, device=param.device, dtype=param.dtype) + tmp_tensor = torch.empty( + param._tp_original_shape, device=param.device, dtype=param.dtype + ) param._init_method(tmp_tensor) param_st = it["offset"] param_end = it["offset"] + it["size"] @@ -446,34 +534,38 @@ def init_parameters(self): continue if param_end <= storage_st: continue - + if param._tp_mode and param._tp_split_dim >= 0: tmp_tensor = tp_split_tensor(tmp_tensor, param._tp_split_dim) # copy to buffer assert tmp_tensor.is_contiguous() and it["size"] == tmp_tensor.numel() - - offset_st = max(storage_st - param_st, 0) - offset_end = min(storage_end - param_st, tmp_tensor.numel()) + + offset_st = max(storage_st - param_st, 0) + offset_end = min(storage_end - param_st, tmp_tensor.numel()) assert offset_st < offset_end # copy to buffer # PyTorch 1.11 changed the API of storage.__getitem__ d_dtype = self._storage_params[kw_name].dtype d_device = self._storage_params[kw_name].device - param.data[:] = \ - torch.tensor([], dtype=d_dtype, device=d_device).set_(tmp_tensor.storage(), offset_st, (offset_end - offset_st,))[:] + param.data[:] = torch.tensor([], dtype=d_dtype, device=d_device).set_( + tmp_tensor.storage(), offset_st, (offset_end - offset_st,) + )[:] del tmp_tensor - - def _named_members(self, get_members_fn, prefix='', recurse=True, **kwargs): + def _named_members(self, get_members_fn, 
prefix="", recurse=True, **kwargs): r"""Helper method for yielding various names + members of modules.""" - - #compitibity with torch 2.0 - if "remove_duplicate" in inspect.signature(torch.nn.Module._named_members).parameters and "remove_duplicate" not in kwargs: - kwargs['remove_duplicate'] = True + + # compitibity with torch 2.0 + if ( + "remove_duplicate" + in inspect.signature(torch.nn.Module._named_members).parameters + and "remove_duplicate" not in kwargs + ): + kwargs["remove_duplicate"] = True return self._module._named_members(get_members_fn, prefix, recurse, **kwargs) - - def named_modules(self, memo = None, prefix: str = '', remove_duplicate: bool = True): + + def named_modules(self, memo=None, prefix: str = "", remove_duplicate: bool = True): r"""Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself. @@ -514,23 +606,24 @@ def named_modules(self, memo = None, prefix: str = '', remove_duplicate: bool = for name, module in self._module._modules.items(): if module is None: continue - submodule_prefix = prefix + ('.' if prefix else '') + name + submodule_prefix = prefix + ("." if prefix else "") + name for m in module.named_modules(memo, submodule_prefix, remove_duplicate): yield m def named_children(self): return self._module.named_children() - + def train(self, mode: bool = True): self._module.train(mode) def eval(self): self._module.eval() - + def __repr__(self): return self._module.__repr__() -def _block_wrapper(module, module_dict:dict, mode="BLOCK"): + +def _block_wrapper(module, module_dict: dict, mode="BLOCK"): if not isinstance(module, Block): in_block = id(module) in module_dict new_module = Block(module, initialized=in_block, mode=mode) @@ -540,14 +633,17 @@ def _block_wrapper(module, module_dict:dict, mode="BLOCK"): module_dict[id(module)] = new_module else: if mode == "PIPE" and module._mode != "PIPE": - assert False, "You must be set mode=\"PIPE\" in bmt.Block when use PipelineTransformerBlockList!" + assert ( + False + ), 'You must be set mode="PIPE" in bmt.Block when use PipelineTransformerBlockList!' if id(module._module) in module_dict: assert False, "Duplicate bmt.Block not supported in same block list!" else: new_module = module module_dict[id(module._module)] = new_module return new_module - + + class TransformerBlockList(torch.nn.Module): r""" TransformerBlockList is a list of bmt.Block. @@ -567,11 +663,12 @@ class TransformerBlockList(torch.nn.Module): >>> hidden_state = transformer_module_list(hidden_state, ...) 
""" + _modules: Dict[str, Block] def __init__(self, modules: Iterable[Block], num_hidden=1) -> None: super().__init__() - + self._modules = {} pre_module = None module_dict = {} @@ -586,10 +683,10 @@ def __init__(self, modules: Iterable[Block], num_hidden=1) -> None: self.add_module(str(i), module) self._modules[str(0)]._is_first_layer = True - self._modules[str(len(modules)-1)]._is_last_layer = True - + self._modules[str(len(modules) - 1)]._is_last_layer = True + self.num_hidden = num_hidden - + def __len__(self) -> int: return len(self._modules) @@ -599,25 +696,27 @@ def __iter__(self) -> Iterator[Block]: def __getitem__(self, index: Union[int, str]) -> Block: return self._modules[str(index)] - def forward(self, *args, return_hidden_states = False): + def forward(self, *args, return_hidden_states=False): self.return_hidden_states = return_hidden_states hidden_states = [] for i in range(len(self)): if return_hidden_states: - for hidden_state in args[:self.num_hidden]: + for hidden_state in args[: self.num_hidden]: hidden_states.append(hidden_state) outputs = self._modules[str(i)]._call_impl(*args) if not isinstance(outputs, tuple): - outputs = (outputs, ) - args = outputs + args[self.num_hidden:] + outputs = (outputs,) + args = outputs + args[self.num_hidden :] if return_hidden_states: hidden_states = [ - torch.stack(hidden_states[i::self.num_hidden], dim=0) + torch.stack(hidden_states[i :: self.num_hidden], dim=0) for i in range(self.num_hidden) ] if return_hidden_states: return outputs + tuple(hidden_states) else: - return tuple(outputs[:self.num_hidden]) if self.num_hidden > 1 else outputs[0] + return ( + tuple(outputs[: self.num_hidden]) if self.num_hidden > 1 else outputs[0] + ) diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 2c6108b0..577331a2 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -2,58 +2,70 @@ from .global_var import config from .zero_context import ZeroContext + def zero_pre_forward(module, inputs): + """Helper function for using ZeroContext to gather parmas before forward.""" enter = True pipe = False if module._mode == "PIPE": enter = module._micro_idx == 0 pipe = True if enter: - zero_level = module._zero_level + zero_level = module._zero_level forward_flag = 1 if zero_level == 2 else 0 if zero_level == 2 and not module._need_release: - forward_flag = 2 # repeating forward in same layer - if module.all_param_no_grad: #only forward + forward_flag = 2 # repeating forward in same layer + if module.all_param_no_grad: # only forward forward_flag = 0 module._forward_block_ctx = ZeroContext(module, module._layer_dict, pipe=pipe) module._forward_block_ctx.enter(forward_flag) + def zero_post_forward(module, inputs, outputs): + """Helper function for module _forwar_block_ctx weather exits after forward.""" forward_flag = 1 if module._zero_level == 2 else 0 if module.all_param_no_grad: forward_flag = 0 exit = True if module._mode == "PIPE": - exit = module._micro_idx == config['micros'] - 1 + exit = module._micro_idx == config["micros"] - 1 if exit: module._forward_block_ctx.exit(forward_flag) + def zero_pre_backward(module, grad_outputs): + """Helper function for using ZeroContext to init grad buffer before backward.""" backward_flag = 2 if module._zero_level == 2 else 0 if module._mode != "PIPE": module._backward_block_ctx = ZeroContext(module, module._layer_dict) module._backward_block_ctx.enter(backward_flag, True) module.release_next_module(backward_flag) else: - if module._micro_idx == config['micros'] - 1: - module._backward_block_ctx = 
ZeroContext(module, module._layer_dict, pipe=True) + if module._micro_idx == config["micros"] - 1: + module._backward_block_ctx = ZeroContext( + module, module._layer_dict, pipe=True + ) module._backward_block_ctx.enter(backward_flag, True) + def zero_post_backward(module, grad_inputs, grad_outputs): + """Helper function that decides whether the module releases its block context after backward.""" backward_flag = 2 if module._zero_level == 2 else 0 if module._mode != "PIPE": - if module._is_first_layer: + if module._is_first_layer: module.release(backward_flag) else: if module._micro_idx == 0: module.release(backward_flag) module._micro_idx -= 1 + class OneStepNoGradFunc(torch.autograd.Function): """ - requires_grad = False for all inputs + Used when requires_grad is False for all inputs. """ + @staticmethod def forward(ctx, module, placeholder, *x): ctx.x = x @@ -80,7 +92,8 @@ def backward(ctx, grads): grads = [] for _ in x: grads.append(None) - return None, None, *grads + return None, None, *grads + class PreHookFunc(torch.autograd.Function): @staticmethod @@ -94,6 +107,7 @@ def backward(ctx, *grads): zero_post_backward(ctx.module, grads, None) return None, *grads + class PostHookFunc(torch.autograd.Function): @staticmethod def forward(ctx, module, *out): diff --git a/bmtrain/init.py b/bmtrain/init.py index 69273c09..601d617e 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -10,13 +10,14 @@ from . import nccl from .synchronize import synchronize + def init_distributed( - init_method : str = "env://", - seed : int = 0, - pipe_size: int = -1, - num_micro_batches: int = None, - tp_size : int = 1, - ): + init_method: str = "env://", + seed: int = 0, + pipe_size: int = -1, + num_micro_batches: int = None, + tp_size: int = 1, +): """Initialize distributed training. This function will initialize the distributed training, set the random seed and global configurations. It must be called before any other distributed functions. @@ -24,17 +25,17 @@ def init_distributed( Args: seed (int): The random seed. pipe_size (int) : pipe_size means that all processes will be divided into pipe_size groups - num_micro_batches (int) : means that the input batchs will be divided into num_micro_batches small batches. used in pipeline mode. + num_micro_batches (int) : means that the input batches will be divided into num_micro_batches small batches. Used in pipeline mode. tp_size (int) : tp_size means the size of each tensor parallel group - **init_distributed** reads the following environment variables: - + **init_distributed** reads the following environment variables: + * `WORLD_SIZE`: The total number of GPUs in the distributed training. * `RANK`: The global rank of the current gpu. From 0 to `WORLD_SIZE - 1`. * `MASTER_ADDR`: The address of the master node. * `MASTER_PORT`: The port of the master node. * `LOCAL_RANK`: The local rank of the current gpu. - + Normally, all the environment variables above are set by the pytorch distributed launcher. **Note**: Do not use any functions in torch.distributed package including `torch.distributed.init_process_group` .
@@ -47,18 +48,18 @@ def init_distributed( local_rank = int(os.environ.get("LOCAL_RANK", "0")) rank = int(os.environ.get("RANK", "0")) world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_size = int(os.environ.get("LOCAL_WORLD_SIZE","1")) + local_size = int(os.environ.get("LOCAL_WORLD_SIZE", "1")) if "MASTER_ADDR" not in os.environ: - os.environ["MASTER_ADDR"]="localhost" + os.environ["MASTER_ADDR"] = "localhost" if "MASTER_PORT" not in os.environ: - os.environ["MASTER_PORT"]="10010" + os.environ["MASTER_PORT"] = "10010" addr = os.environ["MASTER_ADDR"] port = os.environ["MASTER_PORT"] - master = addr+":"+port + master = addr + ":" + port timeout = datetime.timedelta(seconds=1800) rendezvous_iterator = dist.rendezvous( init_method, rank, world_size, timeout=timeout - ) + ) store, rank, world_size = next(rendezvous_iterator) store.set_timeout(timeout) @@ -75,129 +76,159 @@ def init_distributed( config["load_stream"] = torch.cuda.Stream(priority=-1) config["tp_comm_stream"] = torch.cuda.Stream(priority=-1) config["pp_comm_stream"] = torch.cuda.Stream(priority=-1) - config['barrier_stream'] = torch.cuda.Stream() + config["barrier_stream"] = torch.cuda.Stream() config["load_event"] = torch.cuda.Event() config["tp_size"] = tp_size if tp_size > 0 else 1 config["topology"] = topology(config) - config["zero_rank"] = config['topology'].get_group_rank("zero") - config["tp_rank"] = config['topology'].get_group_rank("tp") - config["tp_zero_rank"] = config['topology'].get_group_rank("tp_zero") + config["zero_rank"] = config["topology"].get_group_rank("zero") + config["tp_rank"] = config["topology"].get_group_rank("tp") + config["tp_zero_rank"] = config["topology"].get_group_rank("tp_zero") config["save_param_to_cpu"] = True cpus_this_worker = None - + all_available_cpus = sorted(list(os.sched_getaffinity(0))) cpus_per_worker = len(all_available_cpus) // local_size - + if cpus_per_worker < 1: cpus_this_worker = all_available_cpus torch.set_num_threads(1) else: - cpus_this_worker = all_available_cpus[local_rank * cpus_per_worker : (local_rank + 1) * cpus_per_worker] + cpus_this_worker = all_available_cpus[ + local_rank * cpus_per_worker : (local_rank + 1) * cpus_per_worker + ] os.sched_setaffinity(0, cpus_this_worker) - torch.set_num_threads( len(cpus_this_worker) ) + torch.set_num_threads(len(cpus_this_worker)) torch.manual_seed(seed) random.seed(seed) try: import numpy as np + np.random.seed(seed) except ModuleNotFoundError: pass - + if rank == 0: - unique_id : bytes = nccl.getUniqueId() - store.set("BMTRAIN_UNIQUE_ID", unique_id.hex() ) - + unique_id: bytes = nccl.getUniqueId() + store.set("BMTRAIN_UNIQUE_ID", unique_id.hex()) + unique_id = bytes.fromhex(store.get("BMTRAIN_UNIQUE_ID").decode()) - config['comm'] = nccl.commInitRank(unique_id, world_size, rank) - topo = config['topology'] + config["comm"] = nccl.commInitRank(unique_id, world_size, rank) + topo = config["topology"] - if config['pipe_enabled']: - config["micros"] = num_micro_batches if num_micro_batches else config["pipe_size"] + if config["pipe_enabled"]: + config["micros"] = ( + num_micro_batches if num_micro_batches else config["pipe_size"] + ) if topo.stage_id == 0: unique_id = nccl.getUniqueId() store.set(f"PIPE_UNIQUE_ID{topo.pipe_idx}", unique_id.hex()) unique_id = bytes.fromhex(store.get(f"PIPE_UNIQUE_ID{topo.pipe_idx}").decode()) - config ['pipe_comm'] = nccl.commInitRank(unique_id, pipe_size, topo.stage_id) + config["pipe_comm"] = nccl.commInitRank(unique_id, pipe_size, topo.stage_id) if topo.pp_zero_id == 0: unique_id = 
nccl.getUniqueId() - store.set(f"PP_ZERO_UNIQUE_ID{topo.pp_zero_idx}", unique_id.hex() ) - unique_id = bytes.fromhex(store.get(f"PP_ZERO_UNIQUE_ID{topo.pp_zero_idx}").decode()) - config['pp_zero_comm'] = nccl.commInitRank(unique_id, world_size//config['pipe_size'], topo.pp_zero_id) + store.set(f"PP_ZERO_UNIQUE_ID{topo.pp_zero_idx}", unique_id.hex()) + unique_id = bytes.fromhex( + store.get(f"PP_ZERO_UNIQUE_ID{topo.pp_zero_idx}").decode() + ) + config["pp_zero_comm"] = nccl.commInitRank( + unique_id, world_size // config["pipe_size"], topo.pp_zero_id + ) - if config['tp_size'] > 1: + if config["tp_size"] > 1: if topo.tp_id == 0: unique_id = nccl.getUniqueId() store.set(f"TP_UNIQUE_ID{topo.tp_idx}", unique_id.hex()) unique_id = bytes.fromhex(store.get(f"TP_UNIQUE_ID{topo.tp_idx}").decode()) - config['tp_comm'] = nccl.commInitRank(unique_id, tp_size, topo.tp_id) + config["tp_comm"] = nccl.commInitRank(unique_id, tp_size, topo.tp_id) if topo.tp_zero_id == 0: unique_id = nccl.getUniqueId() - store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() ) - unique_id = bytes.fromhex(store.get(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode()) - config['tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//config['tp_size'], topo.tp_zero_id) + store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex()) + unique_id = bytes.fromhex( + store.get(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode() + ) + config["tp_zero_comm"] = nccl.commInitRank( + unique_id, world_size // config["tp_size"], topo.tp_zero_id + ) - - if config['pipe_size'] > 1 and config['tp_size'] > 1: + if config["pipe_size"] > 1 and config["tp_size"] > 1: if topo.pp_tp_zero_id == 0: unique_id = nccl.getUniqueId() - store.set(f"PP_TP_ZERO_UNIQUE_ID{topo.pp_tp_zero_idx}", unique_id.hex() ) - unique_id = bytes.fromhex(store.get(f"PP_TP_ZERO_UNIQUE_ID{topo.pp_tp_zero_idx}").decode()) - config['pp_tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size'] * config['tp_size']), topo.pp_tp_zero_id) + store.set(f"PP_TP_ZERO_UNIQUE_ID{topo.pp_tp_zero_idx}", unique_id.hex()) + unique_id = bytes.fromhex( + store.get(f"PP_TP_ZERO_UNIQUE_ID{topo.pp_tp_zero_idx}").decode() + ) + config["pp_tp_zero_comm"] = nccl.commInitRank( + unique_id, + world_size // (config["pipe_size"] * config["tp_size"]), + topo.pp_tp_zero_id, + ) - config ['zero_comm'] = config['comm'] + config["zero_comm"] = config["comm"] for i in range(world_size): if i == rank: - print_dict("Initialization", { - "rank": rank, - "local_rank": local_rank, - "world_size": world_size, - "local_size": local_size, - "master" : master, - "device": torch.cuda.current_device(), - "cpus": cpus_this_worker - }) + print_dict( + "Initialization", + { + "rank": rank, + "local_rank": local_rank, + "world_size": world_size, + "local_size": local_size, + "master": master, + "device": torch.cuda.current_device(), + "cpus": cpus_this_worker, + }, + ) synchronize() + class topology: - def __init__(self,config): + """A helper class to keep parallel information when using different parallel methods together.""" + + def __init__(self, config): # pipe_idx is the idx of the pipeline in the group - self.rank = config['rank'] + self.rank = config["rank"] pp_size = config["pipe_size"] tp_size = config["tp_size"] world_size = config["world_size"] - assert world_size % (pp_size * tp_size) == 0, "The nums of GPUs must be divisible by the pipeline parallel size * tensor parallel size" + assert ( + world_size % (pp_size * tp_size) == 0 + ), "The nums of GPUs must be divisible by the 
pipeline parallel size * tensor parallel size" dp_size = world_size // (pp_size * tp_size) - config['tp_zero_size'] = dp_size - config['zero_size'] = world_size // pp_size - self.stages = config['pipe_size'] - + config["tp_zero_size"] = dp_size + config["zero_size"] = world_size // pp_size + self.stages = config["pipe_size"] + stage_size = world_size // pp_size for i in range(world_size): - self.pipe_idx = self.rank % stage_size - self.stage_id = self.rank // stage_size + self.pipe_idx = self.rank % stage_size + self.stage_id = self.rank // stage_size self.tp_id = self.rank % tp_size - self.tp_idx = self.rank // tp_size - #pp->zero - self.pp_zero_idx = self.stage_id - self.pp_zero_id = self.pipe_idx - #tp->zero - self.tp_zero_idx = self.tp_id + self.tp_idx = self.rank // tp_size + # pp->zero + self.pp_zero_idx = self.stage_id + self.pp_zero_id = self.pipe_idx + # tp->zero + self.tp_zero_idx = self.tp_id self.tp_zero_id = self.tp_idx - #pp->tp->zero - self.pp_tp_zero_idx = self.stage_id * tp_size + self.tp_id + # pp->tp->zero + self.pp_tp_zero_idx = self.stage_id * tp_size + self.tp_id self.pp_tp_zero_id = self.pipe_idx // tp_size - #only zero + # only zero self.zero_idx = 0 self.zero_id = self.rank + def get_group_id(self, group_name): + """Get group id of different parallel group. - def get_group_id(self,group_name): + Args: + group_name (str): must be one of "pipe", "zero", "tp_zero" or "tp". + """ if group_name == "pipe": return self.pipe_idx elif group_name == "zero": @@ -206,8 +237,13 @@ def get_group_id(self,group_name): return self.tp_zero_idx elif group_name == "tp": return self.tp_idx - - def get_group_rank(self,group_name): + + def get_group_rank(self, group_name): + """Get group rank of different parallel group. + + Args: + group_name (str): must be one of "pipe", "zero", "tp_zero" or "tp". + """ if group_name == "pipe": return self.stage_id elif group_name == "zero": @@ -217,6 +253,6 @@ def get_group_rank(self,group_name): elif group_name == "tp": return self.tp_id + def is_initialized() -> bool: return config["initialized"] - diff --git a/bmtrain/inspect/tensor.py b/bmtrain/inspect/tensor.py index 2c45fdac..9d003f82 100644 --- a/bmtrain/inspect/tensor.py +++ b/bmtrain/inspect/tensor.py @@ -14,14 +14,15 @@ class InspectTensor: You can get the tensors recorded by `record_tensor`. """ + def __init__(self): self.summary = [] - + def _set_summary(self, summary): self._summary = summary for item in summary: - item['prefix'] = "" if item["group"] is None else f'{item["group"]}.' - + item["prefix"] = "" if item["group"] is None else f'{item["group"]}.' 
+ self.summary = [] kw_cnt = {} @@ -54,58 +55,80 @@ def _set_summary(self, summary): for stage in range(stages): if stage_id == stage: - broadcast_object(pipe_cnt, config["pipe_comm"], src = stage) + broadcast_object(pipe_cnt, config["pipe_comm"], src=stage) for k in range(i, j): item = summary[k] kw = f'{item["prefix"]}{item["name"]}' if kw not in kw_cnt: kw_cnt[kw] = 0 - tensor = torch.cat([summary[k+m*(j-i)]['tensor'] for m in range(config['micros'])], dim=0) - grad = torch.cat([summary[k+m*(j-i)]['tensor'].grad for m in range(config['micros'])], dim=0) if item['requires_grad'] and item['tensor'].grad is not None else None - self.summary.append({ - "name": item["name"], - "summary_name": f'{item["prefix"]}{kw_cnt[kw]}.{item["name"]}', - "group": item["group"], - "min": None, - "max": None, - "mean": None, - "std": None, - "shape": (item["shape"][0] * config['micros'],) + item["shape"][1:], - "grad_mean" : None, - "grad_std" : None, - "tensor": tensor, - "grad": grad, - "requires_grad": item["requires_grad"], - "inside_pipe": {"stage_id": stage}, - }) - kw_cnt[kw] += 1 - else: - cnt = broadcast_object({}, config["pipe_comm"], src = stage) - for kw, val in cnt.items(): - if kw not in kw_cnt: - kw_cnt[kw] = 0 - for _ in range(val): - self.summary.append({ + tensor = torch.cat( + [ + summary[k + m * (j - i)]["tensor"] + for m in range(config["micros"]) + ], + dim=0, + ) + grad = ( + torch.cat( + [ + summary[k + m * (j - i)]["tensor"].grad + for m in range(config["micros"]) + ], + dim=0, + ) + if item["requires_grad"] + and item["tensor"].grad is not None + else None + ) + self.summary.append( + { "name": item["name"], "summary_name": f'{item["prefix"]}{kw_cnt[kw]}.{item["name"]}', - "group": None, + "group": item["group"], "min": None, "max": None, "mean": None, "std": None, - "shape": None, - "grad_mean" : None, - "grad_std" : None, - "tensor": None, - "grad": None, - "requires_grad": None, + "shape": (item["shape"][0] * config["micros"],) + + item["shape"][1:], + "grad_mean": None, + "grad_std": None, + "tensor": tensor, + "grad": grad, + "requires_grad": item["requires_grad"], "inside_pipe": {"stage_id": stage}, - }) + } + ) + kw_cnt[kw] += 1 + else: + cnt = broadcast_object({}, config["pipe_comm"], src=stage) + for kw, val in cnt.items(): + if kw not in kw_cnt: + kw_cnt[kw] = 0 + for _ in range(val): + self.summary.append( + { + "name": item["name"], + "summary_name": f'{item["prefix"]}{kw_cnt[kw]}.{item["name"]}', + "group": None, + "min": None, + "max": None, + "mean": None, + "std": None, + "shape": None, + "grad_mean": None, + "grad_std": None, + "tensor": None, + "grad": None, + "requires_grad": None, + "inside_pipe": {"stage_id": stage}, + } + ) kw_cnt[kw] += 1 after_len = len(self.summary) with torch.enable_grad(): - for it in self.summary[before_len: after_len]: + for it in self.summary[before_len:after_len]: if it["tensor"] is not None: has_grad = it["grad"] is not None info = { @@ -114,53 +137,73 @@ def _set_summary(self, summary): "requires_grad": it["requires_grad"], "has_grad": has_grad, } - broadcast_object(info, config["pipe_comm"], src = it["inside_pipe"]["stage_id"]) + broadcast_object( + info, + config["pipe_comm"], + src=it["inside_pipe"]["stage_id"], + ) tensor = it["tensor"] - tensor = broadcast(tensor, it["inside_pipe"]["stage_id"], config["pipe_comm"]) + tensor = broadcast( + tensor, + it["inside_pipe"]["stage_id"], + config["pipe_comm"], + ) grad = it["grad"] else: - info = broadcast_object({}, config["pipe_comm"], src = it["inside_pipe"]["stage_id"]) + info = 
broadcast_object( + {}, + config["pipe_comm"], + src=it["inside_pipe"]["stage_id"], + ) has_grad = info.pop("has_grad") it.update(info) tensor = torch.empty(it["shape"]).cuda().requires_grad_() - tensor = broadcast(tensor, it["inside_pipe"]["stage_id"], config["pipe_comm"]) + tensor = broadcast( + tensor, + it["inside_pipe"]["stage_id"], + config["pipe_comm"], + ) if has_grad: - grad = torch.empty(it["shape"]).cuda() + grad = torch.empty(it["shape"]).cuda() tensor = tensor.chunk(stages, dim=0)[stage_id].clone() it["tensor"] = tensor if has_grad: - grad = broadcast(grad, it["inside_pipe"]["stage_id"], config["pipe_comm"]) + grad = broadcast( + grad, it["inside_pipe"]["stage_id"], config["pipe_comm"] + ) grad = grad.chunk(stages, dim=0)[stage_id].clone() tensor.grad = grad - it["shape"] = (it["shape"][0]//config["pipe_size"],) + it["shape"][1:] + it["shape"] = (it["shape"][0] // config["pipe_size"],) + it[ + "shape" + ][1:] - i = i + config['micros'] * (j - i) + i = i + config["micros"] * (j - i) else: kw = f'{item["prefix"]}{item["name"]}' if kw not in kw_cnt: kw_cnt[kw] = 0 - self.summary.append({ - "name": item["name"], - "summary_name": f'{item["prefix"]}{kw_cnt[kw]}.{item["name"]}', - "group": item["group"], - "min": None, - "max": None, - "mean": None, - "std": None, - "shape": item["shape"], - "grad_mean" : None, - "grad_std" : None, - "tensor": item["tensor"], - "requires_grad": item["requires_grad"], - "inside_pipe": None, - }) + self.summary.append( + { + "name": item["name"], + "summary_name": f'{item["prefix"]}{kw_cnt[kw]}.{item["name"]}', + "group": item["group"], + "min": None, + "max": None, + "mean": None, + "std": None, + "shape": item["shape"], + "grad_mean": None, + "grad_std": None, + "tensor": item["tensor"], + "requires_grad": item["requires_grad"], + "inside_pipe": None, + } + ) kw_cnt[kw] += 1 i = i + 1 - - def get_summary(self): - """Get the summary of the tensors recorded by `record_tensor`. + r"""Get the summary of the tensors recorded by `record_tensor`. Returns: A list of dicts. 
Each dict contains the following keys: @@ -186,12 +229,7 @@ def get_summary(self): info = torch.empty(2, dtype=x.dtype, device=x.device) info[0] = x.mean() info[1] = x.var() - nccl.allReduce( - info.storage(), - info.storage(), - "sum", - comm - ) + nccl.allReduce(info.storage(), info.storage(), "sum", comm) info = info / nccl.commCount(comm) x_mean = info[0].cpu().item() x_std = math.sqrt(info[1].cpu().item()) @@ -204,12 +242,7 @@ def get_summary(self): info[1] = x.var() info[2] = x.grad.mean() info[3] = x.grad.var() - nccl.allReduce( - info.storage(), - info.storage(), - "sum", - comm - ) + nccl.allReduce(info.storage(), info.storage(), "sum", comm) info = info / nccl.commCount(comm) x_mean = info[0].cpu().item() x_std = math.sqrt(info[1].cpu().item()) @@ -218,14 +251,9 @@ def get_summary(self): info[0] = x.max() info[1] = -x.min() - nccl.allReduce( - info.storage(), - info.storage(), - 'max', - comm - ) + nccl.allReduce(info.storage(), info.storage(), "max", comm) x_max = info[0].cpu().item() - x_min = - info[1].cpu().item() + x_min = -info[1].cpu().item() summary = { "name": item["summary_name"], @@ -233,25 +261,29 @@ def get_summary(self): "max": x_max, "mean": x_mean, "std": x_std, - "shape": tuple((item["shape"][0] * config["world_size"],) + item["shape"][1:]), - "grad_mean" : grad_mean, - "grad_std" : grad_std + "shape": tuple( + (item["shape"][0] * config["world_size"],) + item["shape"][1:] + ), + "grad_mean": grad_mean, + "grad_std": grad_std, } ret.append(summary) return ret - - def get_tensor(self, name : str, group : Optional[str] = None, index : Optional[int] = None) -> torch.Tensor: + + def get_tensor( + self, name: str, group: Optional[str] = None, index: Optional[int] = None + ) -> torch.Tensor: """Get the tensor recorded by `record_tensor` by name, group and index. Args: name (str): The name of the tensor. group (Optional[str]): The group of the tensor. index (Optional[int]): The index of the tensor. - + Returns: The tensor if found, otherwise None. - + """ group_name_prefix = f"{group}." if group is not None else "" @@ -280,7 +312,7 @@ def __enter__(self) -> InspectTensor: return self._inspector else: raise RuntimeError("InspectTensorManager is already in use") - + def __exit__(self, *args): if not self.prev_val: debug.set("_inspect_tensor", self.prev_val) @@ -288,7 +320,7 @@ def __exit__(self, *args): self._inspector._set_summary(summary) self._inspector = None debug.set("_inspect_hidden_states", []) - + def inspect_tensor() -> InspectTensorManager: """**inspect_tensor** returns a context manager that can be used to get the intermediate results of the model computations and their gradients. @@ -310,38 +342,42 @@ def inspect_tensor() -> InspectTensorManager: return InspectTensorManager() -def record_tensor(x : torch.Tensor, name : str, group = None): + +def record_tensor(x: torch.Tensor, name: str, group=None): """Record the tensor for inspection. Args: x (torch.Tensor): The tensor to be recorded. name (str): The name of the tensor. group (str): The group name of the tensor. - + **Note:** This function is only available in inspect_tensor context. **Note:** Recording too many tensors may cause memory issues. 
- + """ if isinstance(x, torch.nn.Parameter): raise RuntimeError("Cannot inspect Parameter") - + if not debug.get("_inspect_tensor", False): # do nothing return if x.requires_grad: x.retain_grad() - debug.append("_inspect_hidden_states", { - "name": name, - "group": group, - "min": None, - "max": None, - "mean": None, - "std": None, - "shape": x.shape, - "grad_mean" : None, - "grad_std" : None, - "tensor": x, - "requires_grad": x.requires_grad, - "inside_pipe": None, - }) + debug.append( + "_inspect_hidden_states", + { + "name": name, + "group": group, + "min": None, + "max": None, + "mean": None, + "std": None, + "shape": x.shape, + "grad_mean": None, + "grad_std": None, + "tensor": x, + "requires_grad": x.requires_grad, + "inside_pipe": None, + }, + ) diff --git a/bmtrain/loss/_function.py b/bmtrain/loss/_function.py index 4ac02f5d..6ff3c471 100644 --- a/bmtrain/loss/_function.py +++ b/bmtrain/loss/_function.py @@ -1,7 +1,9 @@ - -from .. import C +from .. import C import torch + CHECK_INPUT = lambda x: x.is_contiguous() and x.is_cuda + + def has_inf_nan(g_half: torch.Tensor, out: torch.Tensor) -> None: assert out.dtype == torch.uint8, "out must be a uint8 tensor" assert CHECK_INPUT(g_half), "g_fp16 must be contiguous and on cuda" @@ -9,66 +11,120 @@ def has_inf_nan(g_half: torch.Tensor, out: torch.Tensor) -> None: mid = torch.zeros(1024, device=out.device, dtype=out.dtype) stream = torch.cuda.current_stream().cuda_stream if g_half.dtype == torch.float16: - C.has_nan_inf_fp16_launcher(g_half.numel(), g_half.data_ptr(), mid.data_ptr(), out.data_ptr(), stream) + C.has_nan_inf_fp16_launcher( + g_half.numel(), g_half.data_ptr(), mid.data_ptr(), out.data_ptr(), stream + ) elif g_half.dtype == torch.bfloat16: if not C.is_bf16_supported(): raise NotImplementedError(f"bfloat16 is not supported on current GPU") - C.has_nan_inf_bf16_launcher(g_half.numel(), g_half.data_ptr(), mid.data_ptr(), out.data_ptr(), stream) + C.has_nan_inf_bf16_launcher( + g_half.numel(), g_half.data_ptr(), mid.data_ptr(), out.data_ptr(), stream + ) else: raise ValueError(f"has_inf_nan not supported for dtype {g_half.dtype}") -def cross_entropy_forward(m: int, n: int, input: torch.Tensor, target: torch.Tensor, - softmax: torch.Tensor, output: torch.Tensor, ignore_index: int) -> None: + +def cross_entropy_forward( + m: int, + n: int, + input: torch.Tensor, + target: torch.Tensor, + softmax: torch.Tensor, + output: torch.Tensor, + ignore_index: int, +) -> None: CHECK_INPUT(input) CHECK_INPUT(target) CHECK_INPUT(softmax) CHECK_INPUT(output) assert target.dtype == torch.int32, "target must be an int tensor" assert output.dtype == torch.float32, "output must be a float tensor" - assert input.numel() == softmax.numel(), "input and softmax must have the same number of elements" - assert target.numel() == output.numel(), "target and output must have the same number of elements" + assert ( + input.numel() == softmax.numel() + ), "input and softmax must have the same number of elements" + assert ( + target.numel() == output.numel() + ), "target and output must have the same number of elements" input_ptr = input.data_ptr() target_ptr = target.data_ptr() softmax_ptr = softmax.data_ptr() output_ptr = output.data_ptr() cuda_stream = torch.cuda.current_stream().cuda_stream if input.dtype == torch.float16: - C.cross_entropy_forward_fp16_launcher(m, n, input_ptr, target_ptr, softmax_ptr, output_ptr, ignore_index, cuda_stream) + C.cross_entropy_forward_fp16_launcher( + m, + n, + input_ptr, + target_ptr, + softmax_ptr, + output_ptr, + 
ignore_index, + cuda_stream, + ) elif input.dtype == torch.bfloat16: if not C.is_bf16_supported(): raise NotImplementedError(f"bfloat16 is not supported on current GPU") - C.cross_entropy_forward_bf16_launcher(m, n, input_ptr, target_ptr, softmax_ptr, output_ptr, ignore_index, cuda_stream) + C.cross_entropy_forward_bf16_launcher( + m, + n, + input_ptr, + target_ptr, + softmax_ptr, + output_ptr, + ignore_index, + cuda_stream, + ) else: raise ValueError(f"cross_entropy_forward not supported for dtype {input.dtype}") -def cross_entropy_backward_inplace(m: int, n: int, grad_output: torch.Tensor, target: torch.Tensor, - x: torch.Tensor, ignore_index: int) -> None: + +def cross_entropy_backward_inplace( + m: int, + n: int, + grad_output: torch.Tensor, + target: torch.Tensor, + x: torch.Tensor, + ignore_index: int, +) -> None: CHECK_INPUT(grad_output) CHECK_INPUT(target) CHECK_INPUT(x) assert grad_output.dtype == torch.float32, "grad_output must be a float tensor" assert target.dtype == torch.int32, "target must be an int tensor" - assert target.numel() == grad_output.numel(), "target and grad_output must have the same number of elements" + assert ( + target.numel() == grad_output.numel() + ), "target and grad_output must have the same number of elements" cuda_stream = torch.cuda.current_stream().cuda_stream grad_output_ptr = grad_output.data_ptr() target_ptr = target.data_ptr() x_ptr = x.data_ptr() if x.dtype == torch.float16: - C.cross_entropy_backward_inplace_fp16_launcher(m, n, grad_output_ptr, target_ptr, x_ptr, ignore_index, cuda_stream) + C.cross_entropy_backward_inplace_fp16_launcher( + m, n, grad_output_ptr, target_ptr, x_ptr, ignore_index, cuda_stream + ) elif x.dtype == torch.bfloat16: if not C.is_bf16_supported(): raise NotImplementedError(f"bfloat16 is not supported on current GPU") - C.cross_entropy_backward_inplace_bf16_launcher(m, n, grad_output_ptr, target_ptr, x_ptr, ignore_index, cuda_stream) + C.cross_entropy_backward_inplace_bf16_launcher( + m, n, grad_output_ptr, target_ptr, x_ptr, ignore_index, cuda_stream + ) else: - raise ValueError(f"cross_entropy_backward not supported for dtype {input.dtype}") + raise ValueError( + f"cross_entropy_backward not supported for dtype {input.dtype}" + ) + def fused_sumexp(logits: torch.Tensor, max_logits: torch.Tensor) -> torch.Tensor: CHECK_INPUT(logits) CHECK_INPUT(max_logits) assert max_logits.dtype == torch.float32, "max_logits must be float tensor" - assert max_logits.size(0) == logits.size(0), "max_logits must have same size(0) as logits" - sum_exp_logits = torch.empty(logits.size(0), dtype=torch.float32, device=logits.device) + assert max_logits.size(0) == logits.size( + 0 + ), "max_logits must have same size(0) as logits" + sum_exp_logits = torch.empty( + logits.size(0), dtype=torch.float32, device=logits.device + ) m = logits.size(0) n = logits.size(1) cuda_stream = torch.cuda.current_stream().cuda_stream @@ -76,23 +132,34 @@ def fused_sumexp(logits: torch.Tensor, max_logits: torch.Tensor) -> torch.Tensor max_logits_ptr = max_logits.data_ptr() sum_exp_logits_ptr = sum_exp_logits.data_ptr() if logits.dtype == torch.float16: - C.fused_sumexp_fp16_launcher(m, n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr, cuda_stream) + C.fused_sumexp_fp16_launcher( + m, n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr, cuda_stream + ) elif logits.dtype == torch.bfloat16: if not C.is_bf16_supported(): raise NotImplementedError(f"bfloat16 is not supported on current GPU") - C.fused_sumexp_bf16_launcher(m, n, logits_ptr, max_logits_ptr, 
sum_exp_logits_ptr, cuda_stream) + C.fused_sumexp_bf16_launcher( + m, n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr, cuda_stream + ) else: raise ValueError(f"fused_sumexp not supported for dtype {logits.dtype}") return sum_exp_logits -def fused_softmax_inplace(logits: torch.Tensor, max_logits: torch.Tensor, sum_exp_logits: torch.Tensor) -> None: + +def fused_softmax_inplace( + logits: torch.Tensor, max_logits: torch.Tensor, sum_exp_logits: torch.Tensor +) -> None: CHECK_INPUT(logits) CHECK_INPUT(max_logits) CHECK_INPUT(sum_exp_logits) assert max_logits.dtype == torch.float32, "max_logits must be float tensor" assert sum_exp_logits.dtype == torch.float32, "sum_exp_logits must be float tensor" - assert max_logits.size(0) == logits.size(0), "max_logits must have same size(0) as logits" - assert sum_exp_logits.size(0) == logits.size(0), "sum_exp_logits must have same size(0) as logits" + assert max_logits.size(0) == logits.size( + 0 + ), "max_logits must have same size(0) as logits" + assert sum_exp_logits.size(0) == logits.size( + 0 + ), "sum_exp_logits must have same size(0) as logits" m = logits.size(0) n = logits.size(1) cuda_stream = torch.cuda.current_stream().cuda_stream @@ -100,10 +167,16 @@ def fused_softmax_inplace(logits: torch.Tensor, max_logits: torch.Tensor, sum_ex max_logits_ptr = max_logits.data_ptr() sum_exp_logits_ptr = sum_exp_logits.data_ptr() if logits.dtype == torch.float16: - C.fused_softmax_inplace_fp16_launcher(m, n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr, cuda_stream) + C.fused_softmax_inplace_fp16_launcher( + m, n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr, cuda_stream + ) elif logits.dtype == torch.bfloat16: if not C.is_bf16_supported(): raise NotImplementedError(f"bfloat16 is not supported on current GPU") - C.fused_softmax_inplace_bf16_launcher(m, n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr, cuda_stream) + C.fused_softmax_inplace_bf16_launcher( + m, n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr, cuda_stream + ) else: - raise ValueError(f"fused_softmax_inplace not supported for dtype {logits.dtype}") \ No newline at end of file + raise ValueError( + f"fused_softmax_inplace not supported for dtype {logits.dtype}" + ) diff --git a/bmtrain/lr_scheduler/cosine.py b/bmtrain/lr_scheduler/cosine.py index 5a2e931f..3aed034d 100644 --- a/bmtrain/lr_scheduler/cosine.py +++ b/bmtrain/lr_scheduler/cosine.py @@ -4,12 +4,15 @@ class Cosine(WarmupLRScheduler): r""" - After a warmup period during which learning rate increases linearly between 0 and the start_lr, - The decay period performs :math:`\text{lr}=\text{start_lr}\times \dfrac{1+\cos \left( \pi \cdot \dfrac{\text{num_iter}-\text{warmup_iter}}{\text{end_iter}-\text{warmup_iter}}\right)}{2}` + After a warmup period during which learning rate increases linearly between 0 and the start_lr, + The decay period performs :math:`\text{lr}=\text{start_lr}\times \dfrac{1+\cos \left( \pi \cdot \dfrac{\text{num_iter}-\text{warmup_iter}}{\text{end_iter}-\text{warmup_iter}}\right)}{2}` """ + def get_lr_warmup(self, num_iter) -> float: return self.start_lr * num_iter / self.warmup_iter def get_lr_decay(self, num_iter) -> float: - progress = (num_iter - self.warmup_iter) / max(1, (self.end_iter - self.warmup_iter)) + progress = (num_iter - self.warmup_iter) / max( + 1, (self.end_iter - self.warmup_iter) + ) return max(0.0, self.start_lr * 0.5 * (1.0 + math.cos(progress * math.pi))) diff --git a/bmtrain/lr_scheduler/exponential.py b/bmtrain/lr_scheduler/exponential.py index 32ea5a4e..6cf3240e 100644 --- 
a/bmtrain/lr_scheduler/exponential.py +++ b/bmtrain/lr_scheduler/exponential.py @@ -3,10 +3,13 @@ class Exponential(WarmupLRScheduler): r""" - After a warmup period during which learning rate increases linearly between 0 and the start_lr, - The decay period performs :math:`\text{lr}=\text{start_lr}\times \gamma ^ {\left(\text{num_iter}-\text{warmup_iter}\right)}` + After a warmup period during which learning rate increases linearly between 0 and the start_lr, + The decay period performs :math:`\text{lr}=\text{start_lr}\times \gamma ^ {\left(\text{num_iter}-\text{warmup_iter}\right)}` """ - def __init__(self, optimizer, start_lr, warmup_iter, end_iter, num_iter, gamma=0.95) -> None: + + def __init__( + self, optimizer, start_lr, warmup_iter, end_iter, num_iter, gamma=0.95 + ) -> None: super().__init__(optimizer, start_lr, warmup_iter, end_iter, num_iter) self.gamma = gamma diff --git a/bmtrain/lr_scheduler/linear.py b/bmtrain/lr_scheduler/linear.py index 119ae1b7..af193dd8 100644 --- a/bmtrain/lr_scheduler/linear.py +++ b/bmtrain/lr_scheduler/linear.py @@ -3,12 +3,17 @@ class Linear(WarmupLRScheduler): r""" - After a warmup period during which learning rate increases linearly between 0 and the start_lr, - The decay period performs :math:`\text{lr}=\text{start_lr}\times \dfrac{\text{end_iter}-\text{num_iter}}{\text{end_iter}-\text{warmup_iter}}` + After a warmup period during which learning rate increases linearly between 0 and the start_lr, + The decay period performs :math:`\text{lr}=\text{start_lr}\times \dfrac{\text{end_iter}-\text{num_iter}}{\text{end_iter}-\text{warmup_iter}}` """ def get_lr_warmup(self, num_iter) -> float: return self.start_lr * num_iter / self.warmup_iter def get_lr_decay(self, num_iter) -> float: - return max(0.0, self.start_lr * (self.end_iter - num_iter) / (self.end_iter - self.warmup_iter)) + return max( + 0.0, + self.start_lr + * (self.end_iter - num_iter) + / (self.end_iter - self.warmup_iter), + ) diff --git a/bmtrain/lr_scheduler/no_decay.py b/bmtrain/lr_scheduler/no_decay.py index 3cff3b55..6f85bf0a 100644 --- a/bmtrain/lr_scheduler/no_decay.py +++ b/bmtrain/lr_scheduler/no_decay.py @@ -1,18 +1,14 @@ - - from .warmup import WarmupLRScheduler + class NoDecay(WarmupLRScheduler): r""" - After a warmup period during which learning rate increases linearly between 0 and the start_lr, - The decay period performs :math:`\text{lr}=\text{start_lr}` + After a warmup period during which learning rate increases linearly between 0 and the start_lr, + The decay period performs :math:`\text{lr}=\text{start_lr}` """ + def get_lr_warmup(self, num_iter) -> float: return self.start_lr * num_iter / self.warmup_iter - + def get_lr_decay(self, num_iter) -> float: return self.start_lr - - - - \ No newline at end of file diff --git a/bmtrain/lr_scheduler/noam.py b/bmtrain/lr_scheduler/noam.py index c1e8c622..8954a64d 100644 --- a/bmtrain/lr_scheduler/noam.py +++ b/bmtrain/lr_scheduler/noam.py @@ -1,14 +1,15 @@ import math from .warmup import WarmupLRScheduler + class Noam(WarmupLRScheduler): r""" - After a warmup period during which performs :math:`\text{lr}=\text{start_lr}\times \dfrac{\text{num_iter}}{\text{warmup_iter}^{3/2}}`, - The decay period performs :math:`\text{lr}=\text{start_lr}\times \dfrac{\text{1}}{\sqrt{\text{num_iter}}}` + After a warmup period during which performs :math:`\text{lr}=\text{start_lr}\times \dfrac{\text{num_iter}}{\text{warmup_iter}^{3/2}}`, + The decay period performs :math:`\text{lr}=\text{start_lr}\times \dfrac{\text{1}}{\sqrt{\text{num_iter}}}` """ 
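# [Editor's example, not part of this patch] A minimal sketch of how these warmup
# schedulers are driven in a training loop. It assumes Noam is exported as
# bmt.lr_scheduler.Noam and relies only on the WarmupLRScheduler constructor and
# step() shown in the warmup.py hunk below; any torch.optim.Optimizer works.
import torch
import bmtrain as bmt

params = [torch.nn.Parameter(torch.zeros(4))]
optimizer = torch.optim.AdamW(params, lr=1e-3)
lr_scheduler = bmt.lr_scheduler.Noam(
    optimizer, start_lr=1e-3, warmup_iter=2000, end_iter=100000, num_iter=0
)
for _ in range(5):
    optimizer.step()
    lr_scheduler.step()  # warmup: start_lr * num_iter / warmup_iter**1.5, then start_lr / sqrt(num_iter)
    print(optimizer.param_groups[0]["lr"])  # step() writes the new lr into every param group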
def get_lr_warmup(self, num_iter) -> float: return self.start_lr / math.sqrt(self.warmup_iter) * num_iter / self.warmup_iter - + def get_lr_decay(self, num_iter) -> float: return self.start_lr / math.sqrt(num_iter) diff --git a/bmtrain/lr_scheduler/warmup.py b/bmtrain/lr_scheduler/warmup.py index 0f08a600..1f9ccc8e 100644 --- a/bmtrain/lr_scheduler/warmup.py +++ b/bmtrain/lr_scheduler/warmup.py @@ -1,5 +1,6 @@ import torch + class WarmupLRScheduler: r"""Base class for learning rate schedulers with warmup. @@ -10,24 +11,31 @@ class WarmupLRScheduler: end_iter (int): number of iterations to stop training num_iter (int): current iteration number """ - - def __init__(self, optimizer : torch.optim.Optimizer, start_lr, warmup_iter, end_iter, num_iter=0) -> None: + + def __init__( + self, + optimizer: torch.optim.Optimizer, + start_lr, + warmup_iter, + end_iter, + num_iter=0, + ) -> None: self.start_lr = start_lr self.warmup_iter = warmup_iter self.end_iter = end_iter self.optimizer = optimizer self.num_iter = num_iter self._current_lr = None - + self.step(self.num_iter) - + def get_lr_warmup(self, num_iter) -> float: ... def get_lr_decay(self, num_iter) -> float: ... def get_lr(self): assert self.num_iter >= 0 - + if self.num_iter < self.warmup_iter: return self.get_lr_warmup(self.num_iter) else: @@ -37,7 +45,7 @@ def get_lr(self): def current_lr(self): return self._current_lr - def step(self, num_iter = None) -> None: + def step(self, num_iter=None) -> None: if num_iter is None: num_iter = self.num_iter + 1 self.num_iter = num_iter @@ -45,14 +53,14 @@ def step(self, num_iter = None) -> None: lr = self.get_lr() self._current_lr = lr for group in self.optimizer.param_groups: - group['lr'] = lr - + group["lr"] = lr + def state_dict(self): return { "start_lr": self.start_lr, "warmup_iter": self.warmup_iter, "end_iter": self.end_iter, - "num_iter": self.num_iter + "num_iter": self.num_iter, } def load_state_dict(self, state_dict): @@ -62,4 +70,3 @@ def load_state_dict(self, state_dict): self.num_iter = state_dict["num_iter"] self.step(self.num_iter) - diff --git a/bmtrain/nn/column_parallel_linear.py b/bmtrain/nn/column_parallel_linear.py index a432d798..e1ede115 100644 --- a/bmtrain/nn/column_parallel_linear.py +++ b/bmtrain/nn/column_parallel_linear.py @@ -3,35 +3,78 @@ import bmtrain as bmt from bmtrain.global_var import config -from .parallel_linear_func import ( - OpParallelLinear, - ReduceType) +from .parallel_linear_func import OpParallelLinear, ReduceType + class ColumnParallelLinear(bmt.DistributedModule): - def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None, gather_output=False, gather_input=True, async_gather_chunks=2) -> None: + """Tensor parallel Linear layer using column partition. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool): whether to add a learnable bias. + dtype : data type. + gather_output (bool): whether to gather the output across tensor-parallel ranks after the computation. + gather_input (bool): whether to gather the input across tensor-parallel ranks before the computation. + async_gather_chunks (int): number of chunks used for asynchronous gathering.
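# [Editor's usage sketch, not part of this patch] Assuming this class is exported as
# bmt.nn.ColumnParallelLinear and that bmt.init_distributed() has been called under a
# distributed launcher; the tp_size keyword here is an assumption about that call.
import torch
import bmtrain as bmt

bmt.init_distributed(tp_size=2)
# weight is sharded per rank as [out_features // tp_size, in_features] (see __init__ below)
layer = bmt.nn.ColumnParallelLinear(1024, 4096, dtype=torch.half, gather_input=False, gather_output=True)
x = torch.randn(8, 1024, dtype=torch.half, device="cuda")
y = layer(x)  # with gather_output=True the partitioned outputs are reassembled into shape [8, 4096]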
+ + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype=None, + gather_output=False, + gather_input=True, + async_gather_chunks=2, + ) -> None: super().__init__() self.in_features = in_features self.out_features = out_features self.gather_output = gather_output self.gather_input = gather_input self.async_gather_chunks = async_gather_chunks - tp_size = config['tp_size'] + tp_size = config["tp_size"] assert out_features % tp_size == 0 self.out_features_per_partition = out_features // tp_size - self.weight = bmt.DistributedParameter(torch.empty(self.out_features_per_partition, in_features, dtype=dtype, device="cuda"), init_method=torch.nn.init.xavier_normal_, tp_split_dim=0, tp_mode=True) + self.weight = bmt.DistributedParameter( + torch.empty( + self.out_features_per_partition, in_features, dtype=dtype, device="cuda" + ), + init_method=torch.nn.init.xavier_normal_, + tp_split_dim=0, + tp_mode=True, + ) if bias: - self.bias = bmt.DistributedParameter(torch.empty(self.out_features_per_partition, dtype=dtype, device="cuda"), init_method=torch.nn.init.zeros_, tp_split_dim=0, tp_mode=True) + self.bias = bmt.DistributedParameter( + torch.empty( + self.out_features_per_partition, dtype=dtype, device="cuda" + ), + init_method=torch.nn.init.zeros_, + tp_split_dim=0, + tp_mode=True, + ) else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) def forward(self, input): - gather_input = self.gather_input + gather_input = self.gather_input split_input = False - reduce_output_type = None - return OpParallelLinear.apply(input, self.weight, self.bias, gather_input, self.gather_output, split_input, reduce_output_type, self.async_gather_chunks) + reduce_output_type = None + return OpParallelLinear.apply( + input, + self.weight, + self.bias, + gather_input, + self.gather_output, + split_input, + reduce_output_type, + self.async_gather_chunks, + ) def extra_repr(self) -> str: - return 'in_features={}, out_features={}, bias={}'.format( + return "in_features={}, out_features={}, bias={}".format( self.in_features, self.out_features_per_partitions, self.bias is not None ) - diff --git a/bmtrain/nn/linear.py b/bmtrain/nn/linear.py index cb04863a..8afb1d89 100644 --- a/bmtrain/nn/linear.py +++ b/bmtrain/nn/linear.py @@ -2,6 +2,7 @@ import torch.nn.functional as F import bmtrain as bmt + class OpLinear(torch.autograd.Function): @staticmethod def forward(ctx, x, weight, bias=None): @@ -16,28 +17,40 @@ def backward(ctx, grad_output): grad_x = grad_output.matmul(weight) if weight.requires_grad: dim = grad_output.dim() - grad_weight = grad_output.reshape(-1, - grad_output.shape[-1]).t().matmul(x.reshape(-1, x.shape[-1])) + grad_weight = ( + grad_output.reshape(-1, grad_output.shape[-1]) + .t() + .matmul(x.reshape(-1, x.shape[-1])) + ) if bias is not None and bias.requires_grad: grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0) return grad_x, grad_weight, grad_bias + class Linear(bmt.DistributedModule): - def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None: + def __init__( + self, in_features: int, out_features: int, bias: bool = True, dtype=None + ) -> None: super().__init__() self.in_features = in_features self.out_features = out_features - self.weight = bmt.DistributedParameter(torch.empty(out_features, in_features, dtype=dtype, device="cuda"), init_method=torch.nn.init.xavier_normal_) + self.weight = bmt.DistributedParameter( + torch.empty(out_features, in_features, dtype=dtype, 
device="cuda"), + init_method=torch.nn.init.xavier_normal_, + ) if bias: - self.bias = bmt.DistributedParameter(torch.empty(out_features, dtype=dtype, device="cuda"), init_method=torch.nn.init.zeros_) + self.bias = bmt.DistributedParameter( + torch.empty(out_features, dtype=dtype, device="cuda"), + init_method=torch.nn.init.zeros_, + ) else: - self.register_parameter('bias', None) - + self.register_parameter("bias", None) + def forward(self, input): return OpLinear.apply(input, self.weight, self.bias) def extra_repr(self) -> str: - return 'in_features={}, out_features={}, bias={}'.format( + return "in_features={}, out_features={}, bias={}".format( self.in_features, self.out_features, self.bias is not None ) diff --git a/bmtrain/nn/parallel_embedding.py b/bmtrain/nn/parallel_embedding.py index 43e7397d..3bdc4e56 100644 --- a/bmtrain/nn/parallel_embedding.py +++ b/bmtrain/nn/parallel_embedding.py @@ -6,10 +6,21 @@ import bmtrain as bmt from bmtrain.global_var import config from bmtrain.distributed import all_reduce, all_gather -from .parallel_linear_func import OpParallelLinear +from .parallel_linear_func import OpParallelLinear class VPEmbedding(bmt.DistributedModule): + """Vocab Parallel Embedding. + + Args: + vocab_size (int required): vocab size. + embedding_size (int required): embedding size. + dtype (torch.dtype): data type. + init_mean (float optional): mean for weight init. + init_std (float optional): std for weight init. + + """ + def __init__( self, vocab_size: int, @@ -27,16 +38,22 @@ def __init__( self.end_index = (bmt.config["tp_rank"] + 1) * self.vocab_size_per_partition self.weight = bmt.DistributedParameter( torch.empty(self.vocab_size_per_partition, embedding_size, dtype=dtype), - init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std), + init_method=bmt.ParameterInitializer( + torch.nn.init.normal_, mean=init_mean, std=init_std + ), tp_split_dim=0, tp_mode=True, ) def forward(self, x: torch.Tensor, projection=False): if not projection: - weight = all_gather(self.weight, comm=config['tp_comm']).flatten(0,1) + weight = all_gather(self.weight, comm=config["tp_comm"]).flatten(0, 1) out = F.embedding(x, weight) return out else: - x = bmt.distributed.all_gather(x, comm=bmt.config['tp_comm']).view(x.shape[0], -1, x.shape[-1]) - return bmt.nn.OpParallelLinear.apply(x, self.weight, None, False, False, False, None, 1) + x = bmt.distributed.all_gather(x, comm=bmt.config["tp_comm"]).view( + x.shape[0], -1, x.shape[-1] + ) + return bmt.nn.OpParallelLinear.apply( + x, self.weight, None, False, False, False, None, 1 + ) diff --git a/bmtrain/nn/parallel_linear_func.py b/bmtrain/nn/parallel_linear_func.py index e242f7ed..e389cde6 100644 --- a/bmtrain/nn/parallel_linear_func.py +++ b/bmtrain/nn/parallel_linear_func.py @@ -1,51 +1,54 @@ import torch import torch.nn.functional as F -from bmtrain.global_var import config +from bmtrain.global_var import config from ..distributed import all_gather, all_reduce from .. 
import nccl import bmtrain as bmt from enum import Enum + class ReduceType(Enum): ALL_REDUCE = 1 REDUCE_SCATTER = 2 + def preprocess_input(input, gather_input, split_input): if gather_input: - input = all_gather(input, config['tp_comm']) + input = all_gather(input, config["tp_comm"]) input = input.flatten(0, 1) if split_input: - all_input_list = input.chunk(config['tp_size'], dim=-1) - input = all_input_list[config['topology'].tp_id] + all_input_list = input.chunk(config["tp_size"], dim=-1) + input = all_input_list[config["topology"].tp_id] return input + def async_all_gather_linear_func(input, weight, bias, async_chunks=2): dim = input.dim() shape = list(input.shape) if dim > 2: input = input.view(-1, input.shape[-1]) - tp_size = config['tp_size'] + tp_size = config["tp_size"] current_stream = torch.cuda.current_stream() - comm_stream = config['tp_comm_stream'] + comm_stream = config["tp_comm_stream"] rounds = async_chunks inputs = input.chunk(rounds, dim=0) comm_stream.wait_stream(current_stream) outputs = [None] * tp_size * rounds - input = all_gather(inputs[0], config['tp_comm']) + input = all_gather(inputs[0], config["tp_comm"]) input = input.flatten(0, 1) out = F.linear(input, weight, bias) outs = out.chunk(tp_size, dim=0) for i in range(tp_size): outputs[i * rounds] = outs[i] - #async all_gather and overalap with linear - for i in range(rounds-1): + # async all_gather and overalap with linear + for i in range(rounds - 1): with torch.cuda.stream(comm_stream): - inputs[i+1].record_stream(comm_stream) - input = all_gather(inputs[i+1], config['tp_comm']) + inputs[i + 1].record_stream(comm_stream) + input = all_gather(inputs[i + 1], config["tp_comm"]) input = input.flatten(0, 1) current_stream.wait_stream(comm_stream) @@ -62,31 +65,34 @@ def async_all_gather_linear_func(input, weight, bias, async_chunks=2): out = out.view(shape) return out + def async_reduce_scatter_linear_func(input, weight, bias, async_chunks=2): - tp_size = config['tp_size'] - comm_stream = config['tp_comm_stream'] + tp_size = config["tp_size"] + comm_stream = config["tp_comm_stream"] rounds = async_chunks input_shape = list(input.shape) dim = input.dim() if dim > 2: input = input.view(-1, input.shape[-1]) - inputs = input.chunk(rounds*tp_size, dim=0) + inputs = input.chunk(rounds * tp_size, dim=0) current_stream = torch.cuda.current_stream() - outputs = [None] * rounds + outputs = [None] * rounds for i in range(rounds): input = [None] * tp_size for j in range(tp_size): - input[j] = inputs[j*rounds + i] + input[j] = inputs[j * rounds + i] input = torch.cat(input, dim=0) out = F.linear(input, weight, bias) with torch.cuda.stream(comm_stream): comm_stream.wait_stream(current_stream) out.record_stream(comm_stream) shape = list(out.shape) - shape[0] = shape[0] // config['tp_size'] + shape[0] = shape[0] // config["tp_size"] outputs[i] = torch.empty(shape, dtype=out.dtype, device=out.device) - nccl.reduceScatter(out.storage(), outputs[i].storage(), "sum", config['tp_comm']) + nccl.reduceScatter( + out.storage(), outputs[i].storage(), "sum", config["tp_comm"] + ) current_stream.wait_stream(comm_stream) out = torch.cat(outputs, dim=0) @@ -98,10 +104,13 @@ def async_reduce_scatter_linear_func(input, weight, bias, async_chunks=2): return out -def async_all_gather_linear_backward_func(grad_out, input, weight, bias, async_chunks=2): - tp_size = config['tp_size'] + +def async_all_gather_linear_backward_func( + grad_out, input, weight, bias, async_chunks=2 +): + tp_size = config["tp_size"] current_stream = 
torch.cuda.current_stream() - comm_stream = config['tp_comm_stream'] + comm_stream = config["tp_comm_stream"] input_require_grad = input.requires_grad dim = input.dim() input_shape = input.shape @@ -110,8 +119,8 @@ def async_all_gather_linear_backward_func(grad_out, input, weight, bias, async_c grad_out = grad_out.view(-1, grad_out.shape[-1]) rounds = async_chunks - grad_inputs = [None] * tp_size * rounds - grad_weights = [None] * tp_size * rounds + grad_inputs = [None] * tp_size * rounds + grad_weights = [None] * tp_size * rounds grad_outs = [None] * tp_size * rounds local_grad_outs = grad_out.chunk(rounds, dim=0) @@ -135,40 +144,50 @@ def async_all_gather_linear_backward_func(grad_out, input, weight, bias, async_c grad_input = grad_weight = grad_bias = None - grad_out = all_gather(local_grad_outs[0], config['tp_comm']) + grad_out = all_gather(local_grad_outs[0], config["tp_comm"]) for j in range(tp_size): grad_outs[j * rounds] = grad_out[j] - grad_out = grad_out.flatten(0, 1) # (tp_size * (m/rounds), n) + grad_out = grad_out.flatten(0, 1) # (tp_size * (m/rounds), n) if input_require_grad: - grad_input = grad_out.matmul(weight) # (tp_size * (m/rounds), n) * (n, k/tp_size) + grad_input = grad_out.matmul( + weight + ) # (tp_size * (m/rounds), n) * (n, k/tp_size) tmp_grad_inputs = grad_input.chunk(tp_size, dim=0) for j in range(tp_size): grad_inputs[j * rounds] = tmp_grad_inputs[j] if weight.requires_grad: - grad_weight = grad_out.reshape(-1, - grad_out.shape[-1]).t().matmul(inputs[0].reshape(-1, inputs[0].shape[-1])) - - #async all_gather and overalap with matmul - for i in range(rounds-1): + grad_weight = ( + grad_out.reshape(-1, grad_out.shape[-1]) + .t() + .matmul(inputs[0].reshape(-1, inputs[0].shape[-1])) + ) + + # async all_gather and overlap with matmul + for i in range(rounds - 1): with torch.cuda.stream(comm_stream): - local_grad_outs[i+1].record_stream(comm_stream) - grad_out = all_gather(local_grad_outs[i+1], config['tp_comm']) + local_grad_outs[i + 1].record_stream(comm_stream) + grad_out = all_gather(local_grad_outs[i + 1], config["tp_comm"]) for j in range(tp_size): - grad_outs[j * rounds + i+1] = grad_out[j] - grad_out = grad_out.flatten(0, 1) # (tp_size * (m/rounds), n) + grad_outs[j * rounds + i + 1] = grad_out[j] + grad_out = grad_out.flatten(0, 1) # (tp_size * (m/rounds), n) current_stream.wait_stream(comm_stream) if input_require_grad: - grad_input = grad_out.matmul(weight) # (tp_size * (m/rounds), n) * (n, k/tp_size) + grad_input = grad_out.matmul( + weight + ) # (tp_size * (m/rounds), n) * (n, k/tp_size) tmp_grad_inputs = grad_input.chunk(tp_size, dim=0) for j in range(tp_size): - grad_inputs[j * rounds + i+1] = tmp_grad_inputs[j] + grad_inputs[j * rounds + i + 1] = tmp_grad_inputs[j] if weight.requires_grad: dim = grad_out.dim() - grad_weight += grad_out.reshape(-1, - grad_out.shape[-1]).t().matmul(inputs[i+1].reshape(-1, inputs[i+1].shape[-1])) + grad_weight += ( + grad_out.reshape(-1, grad_out.shape[-1]) + .t() + .matmul(inputs[i + 1].reshape(-1, inputs[i + 1].shape[-1])) + ) if input_require_grad: grad_input = torch.cat(grad_inputs, dim=0) @@ -180,9 +199,25 @@ def async_all_gather_linear_backward_func(grad_out, input, weight, bias, async_c return grad_input, grad_weight, grad_bias + class OpParallelLinear(torch.autograd.Function): + """OpParallelLinear is a subclass of torch.autograd.Function. + It gathers the input tensor when needed, and all-reduces or reduce-scatters the output when needed.
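# [Editor's note, illustrative only, not part of this patch] The positional arguments of
# OpParallelLinear.apply in the forward() below are (input, weight, bias, gather_input,
# gather_output, split_input, reduce_output_type, async_gather_chunks). A minimal call
# that does a local matmul and gathers the feature dimension back, assuming
# bmt.init_distributed() has already populated config:
import torch
from bmtrain.global_var import config
from bmtrain.nn.parallel_linear_func import OpParallelLinear

w = torch.randn(4096 // config["tp_size"], 1024, dtype=torch.half, device="cuda", requires_grad=True)
x = torch.randn(8, 1024, dtype=torch.half, device="cuda", requires_grad=True)
y = OpParallelLinear.apply(x, w, None, False, True, False, None, 2)  # y has shape [8, 4096] on every rank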
+ + """ + @staticmethod - def forward(ctx, input, weight, bias=None, gather_input=False, gather_output=False, split_input=False, reduce_output_type=None, async_gather_chunks=2): + def forward( + ctx, + input, + weight, + bias=None, + gather_input=False, + gather_output=False, + split_input=False, + reduce_output_type=None, + async_gather_chunks=2, + ): if reduce_output_type is not None: reduce_output_type = ReduceType(reduce_output_type) @@ -193,28 +228,37 @@ def forward(ctx, input, weight, bias=None, gather_input=False, gather_output=Fal ctx.reduce_output_type = reduce_output_type ctx.async_gather_chunks = async_gather_chunks - if gather_input and config['tp_size'] > 1 and async_gather_chunks > 1 and split_input == False: + if ( + gather_input + and config["tp_size"] > 1 + and async_gather_chunks > 1 + and split_input == False + ): out = async_all_gather_linear_func(input, weight, bias, async_gather_chunks) elif reduce_output_type == ReduceType.REDUCE_SCATTER: - return async_reduce_scatter_linear_func(input, weight, bias, async_gather_chunks) + return async_reduce_scatter_linear_func( + input, weight, bias, async_gather_chunks + ) else: all_input = preprocess_input(input, ctx.gather_input, ctx.split_input) out = F.linear(all_input, weight, bias) if gather_output: - all_output_list = all_gather(out, config['tp_comm']) - all_output_list = all_output_list.chunk(config['tp_size'], dim=0) - out = torch.cat(all_output_list, dim=all_output_list[0].dim()-1).flatten(0,1) + all_output_list = all_gather(out, config["tp_comm"]) + all_output_list = all_output_list.chunk(config["tp_size"], dim=0) + out = torch.cat(all_output_list, dim=all_output_list[0].dim() - 1).flatten( + 0, 1 + ) if reduce_output_type is None: return out if reduce_output_type == ReduceType.ALL_REDUCE: - nccl.allReduce(out.storage(), out.storage(), "sum", config['tp_comm']) - return out + nccl.allReduce(out.storage(), out.storage(), "sum", config["tp_comm"]) + return out else: assert False, "no support reduce type{}".format(reduce_output_type) - + @staticmethod def backward(ctx, grad_output): input, weight, bias = ctx.saved_tensors @@ -222,15 +266,19 @@ def backward(ctx, grad_output): if ctx.reduce_output_type == ReduceType.REDUCE_SCATTER: if input.requires_grad or weight.requires_grad: - grad_input, grad_weight, grad_bias = async_all_gather_linear_backward_func(grad_output, input, weight, bias, ctx.async_gather_chunks) - return grad_input, grad_weight, grad_bias, None, None, None, None, None + grad_input, grad_weight, grad_bias = ( + async_all_gather_linear_backward_func( + grad_output, input, weight, bias, ctx.async_gather_chunks + ) + ) + return grad_input, grad_weight, grad_bias, None, None, None, None, None else: - grad_output = all_gather(grad_output, config['tp_comm']) + grad_output = all_gather(grad_output, config["tp_comm"]) grad_output = grad_output.flatten(0, 1) if gather_output: - tp_size = config['tp_size'] - tp_id = config['topology'].tp_id + tp_size = config["tp_size"] + tp_id = config["topology"].tp_id grad_output_list = grad_output.chunk(tp_size, dim=-1) grad_output = grad_output_list[tp_id] @@ -240,12 +288,14 @@ def backward(ctx, grad_output): if input.requires_grad or weight.requires_grad: if ctx.gather_input: # async the all_gather - with torch.cuda.stream(config['tp_comm_stream']): - input.record_stream(config['tp_comm_stream']) - config['tp_comm_stream'].wait_stream(current_stream) - all_input = preprocess_input(input, ctx.gather_input, ctx.split_input) - #use event to solve two streams waiting for each other - 
gather_event = config['tp_comm_stream'].record_event() + with torch.cuda.stream(config["tp_comm_stream"]): + input.record_stream(config["tp_comm_stream"]) + config["tp_comm_stream"].wait_stream(current_stream) + all_input = preprocess_input( + input, ctx.gather_input, ctx.split_input + ) + # use event to solve two streams waiting for each other + gather_event = config["tp_comm_stream"].record_event() else: all_input = preprocess_input(input, ctx.gather_input, ctx.split_input) @@ -254,36 +304,49 @@ def backward(ctx, grad_output): grad_input = torch.zeros_like(input) if ctx.gather_input: # async the reduce_scatter - with torch.cuda.stream(config['tp_comm_stream']): - config['tp_comm_stream'].wait_stream(current_stream) - grad_input.record_stream(config['tp_comm_stream']) - grad_all_input.record_stream(config['tp_comm_stream']) - nccl.reduceScatter(grad_all_input.storage(), grad_input.storage(), "sum", config['tp_comm']) + with torch.cuda.stream(config["tp_comm_stream"]): + config["tp_comm_stream"].wait_stream(current_stream) + grad_input.record_stream(config["tp_comm_stream"]) + grad_all_input.record_stream(config["tp_comm_stream"]) + nccl.reduceScatter( + grad_all_input.storage(), + grad_input.storage(), + "sum", + config["tp_comm"], + ) elif ctx.reduce_output_type is None: - with torch.cuda.stream(config['tp_comm_stream']): - config['tp_comm_stream'].wait_stream(current_stream) - grad_input.record_stream(config['tp_comm_stream']) - nccl.allReduce(grad_all_input.storage(), grad_all_input.storage(), "sum", config['tp_comm']) + with torch.cuda.stream(config["tp_comm_stream"]): + config["tp_comm_stream"].wait_stream(current_stream) + grad_input.record_stream(config["tp_comm_stream"]) + nccl.allReduce( + grad_all_input.storage(), + grad_all_input.storage(), + "sum", + config["tp_comm"], + ) grad_input = grad_all_input else: grad_input = grad_all_input if ctx.split_input: - with torch.cuda.stream(config['tp_comm_stream']): - config['tp_comm_stream'].wait_stream(current_stream) - grad_input.record_stream(config['tp_comm_stream']) - grad_input = all_gather(grad_input, config['tp_comm']) + with torch.cuda.stream(config["tp_comm_stream"]): + config["tp_comm_stream"].wait_stream(current_stream) + grad_input.record_stream(config["tp_comm_stream"]) + grad_input = all_gather(grad_input, config["tp_comm"]) - # wait all_gather + # wait all_gather if ctx.gather_input: current_stream.wait_event(gather_event) if weight.requires_grad: - grad_weight = grad_output.reshape(-1, - grad_output.shape[-1]).t().matmul(all_input.reshape(-1, all_input.shape[-1])) - + grad_weight = ( + grad_output.reshape(-1, grad_output.shape[-1]) + .t() + .matmul(all_input.reshape(-1, all_input.shape[-1])) + ) + if bias is not None and bias.requires_grad: grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0) current_stream = torch.cuda.current_stream() - current_stream.wait_stream(config['tp_comm_stream']) + current_stream.wait_stream(config["tp_comm_stream"]) return grad_input, grad_weight, grad_bias, None, None, None, None, None diff --git a/bmtrain/nn/row_parallel_linear.py b/bmtrain/nn/row_parallel_linear.py index 7451e7d3..ee4610cc 100644 --- a/bmtrain/nn/row_parallel_linear.py +++ b/bmtrain/nn/row_parallel_linear.py @@ -3,37 +3,86 @@ import bmtrain as bmt from bmtrain.global_var import config -from .parallel_linear_func import ( - OpParallelLinear, - ReduceType) +from .parallel_linear_func import OpParallelLinear, ReduceType + class RowParallelLinear(bmt.DistributedModule): - def __init__(self, in_features : int, 
out_features: int, bias: bool = True, dtype = None, split_input=False, all_reduce_output=False, async_chunks=2) -> None: + """Tensor parallel Linear layer using row partition. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool): whether to add a learnable bias. + dtype : data type. + split_input (bool): whether to split the input across tensor-parallel ranks before the computation. + all_reduce_output (bool): if True, all-reduce the output after the computation; otherwise use reduce-scatter. + async_chunks (int): number of chunks used for asynchronous communication. + + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype=None, + split_input=False, + all_reduce_output=False, + async_chunks=2, + ) -> None: super().__init__() self.in_features = in_features self.out_features = out_features self.split_input = split_input self.all_reduce_output = all_reduce_output self.async_chunks = async_chunks - tp_size = config['tp_size'] + tp_size = config["tp_size"] assert in_features % tp_size == 0 self.in_features_per_partition = in_features // tp_size - self.weight = bmt.DistributedParameter(torch.empty(self.out_features, self.in_features_per_partition, dtype=dtype, device="cuda"), init_method=torch.nn.init.xavier_normal_, tp_split_dim=1, tp_mode=True) + self.weight = bmt.DistributedParameter( + torch.empty( + self.out_features, + self.in_features_per_partition, + dtype=dtype, + device="cuda", + ), + init_method=torch.nn.init.xavier_normal_, + tp_split_dim=1, + tp_mode=True, + ) if bias: - self.bias = bmt.DistributedParameter(torch.empty(self.out_features, dtype=dtype, device="cuda"), init_method=torch.nn.init.zeros_, tp_split_dim=-1, tp_mode=True) + self.bias = bmt.DistributedParameter( + torch.empty(self.out_features, dtype=dtype, device="cuda"), + init_method=torch.nn.init.zeros_, + tp_split_dim=-1, + tp_mode=True, + ) else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) def forward(self, input): gather_input = self.split_input gather_output = False - reduce_output_type = ReduceType.ALL_REDUCE if self.all_reduce_output else ReduceType.REDUCE_SCATTER - out = OpParallelLinear.apply(input, self.weight, None, gather_input, gather_output, self.split_input, reduce_output_type, self.async_chunks) + reduce_output_type = ( + ReduceType.ALL_REDUCE + if self.all_reduce_output + else ReduceType.REDUCE_SCATTER + ) + out = OpParallelLinear.apply( + input, + self.weight, + None, + gather_input, + gather_output, + self.split_input, + reduce_output_type, + self.async_chunks, + ) if self.bias is not None: out = out + self.bias return out def extra_repr(self) -> str: - return 'in_features={}, out_features={}, bias={}'.format( + return "in_features={}, out_features={}, bias={}".format( self.in_features_per_partition, self.out_features, self.bias is not None ) diff --git a/bmtrain/optim/_distributed.py b/bmtrain/optim/_distributed.py index 11daa2b0..df8f2f3e 100644 --- a/bmtrain/optim/_distributed.py +++ b/bmtrain/optim/_distributed.py @@ -1,29 +1,40 @@ import torch from ..distributed import all_reduce, all_gather + def state_dict_gather(state_dict): - param_key = [p for param_group in state_dict['param_groups'] for p in param_group['params'] ] - for k, v in state_dict['state'].items(): + param_key = [ + p for param_group in state_dict["param_groups"] for p in param_group["params"] + ] + for k, v in state_dict["state"].items(): if "step" in v: - step = v['step'] + step = v["step"] for k in param_key: - if k not in state_dict['state']: - state_dict['state'][k] = { - 'exp_avg' : torch.tensor([],
device="cuda", dtype=torch.float32), - 'exp_avg_sq' : torch.tensor([], device="cuda", dtype=torch.float32), - '_param_fp32' : torch.tensor([], device="cuda", dtype=torch.float32), - 'step' : step + if k not in state_dict["state"]: + state_dict["state"][k] = { + "exp_avg": torch.tensor([], device="cuda", dtype=torch.float32), + "exp_avg_sq": torch.tensor([], device="cuda", dtype=torch.float32), + "_param_fp32": torch.tensor([], device="cuda", dtype=torch.float32), + "step": step, } - v = state_dict['state'][k] - for name, dtype in [("exp_avg", torch.float32), ("exp_avg_sq", torch.float32), ("_param_fp32", torch.float32)]: + v = state_dict["state"][k] + for name, dtype in [ + ("exp_avg", torch.float32), + ("exp_avg_sq", torch.float32), + ("_param_fp32", torch.float32), + ]: if name in v: with torch.no_grad(): - numel = torch.tensor(v[name].numel(), device="cuda", dtype=torch.long) + numel = torch.tensor( + v[name].numel(), device="cuda", dtype=torch.long + ) max_numel = all_reduce(numel, op="max") - v_p = torch.nn.functional.pad(v[name], (0, max_numel - numel), value=-1e15) + v_p = torch.nn.functional.pad( + v[name], (0, max_numel - numel), value=-1e15 + ) if max_numel > 0: whole_state = all_gather(v_p.cuda()).flatten() whole_state = whole_state[whole_state != -1e15] v[name] = whole_state.contiguous().cpu() - return state_dict \ No newline at end of file + return state_dict diff --git a/bmtrain/optim/_function.py b/bmtrain/optim/_function.py index d4584457..f9e0ce9d 100644 --- a/bmtrain/optim/_function.py +++ b/bmtrain/optim/_function.py @@ -1,28 +1,52 @@ from .. import C import torch + CHECK_INPUT = lambda x: x.is_contiguous() and x.is_cuda + def bf16_from_fp32(param_fp32): param_bf16 = torch.empty_like(param_fp32, dtype=torch.bfloat16) - C.to_bf16_from_fp32(param_fp32.numel(), param_fp32.data_ptr(), param_bf16.data_ptr()) + C.to_bf16_from_fp32( + param_fp32.numel(), param_fp32.data_ptr(), param_bf16.data_ptr() + ) return param_bf16 + def fp16_from_fp32(param_fp32): param_fp16 = torch.empty_like(param_fp32, dtype=torch.float16) - C.to_fp16_from_fp32(param_fp32.numel(), param_fp32.data_ptr(), param_fp16.data_ptr()) + C.to_fp16_from_fp32( + param_fp32.numel(), param_fp32.data_ptr(), param_fp16.data_ptr() + ) return param_fp16 - -def adam_cpu(param_fp32: torch.Tensor, param_fp16: torch.Tensor, delta_info: torch.Tensor, g_fp16: torch.Tensor, m_fp32: torch.Tensor, - v_fp32: torch.Tensor, beta1: float, beta2: float, eps: float, lr: float, scale: float, - weight_decay: float, step: int) -> None: + + +def adam_cpu( + param_fp32: torch.Tensor, + param_fp16: torch.Tensor, + delta_info: torch.Tensor, + g_fp16: torch.Tensor, + m_fp32: torch.Tensor, + v_fp32: torch.Tensor, + beta1: float, + beta2: float, + eps: float, + lr: float, + scale: float, + weight_decay: float, + step: int, +) -> None: assert param_fp32.is_contiguous(), "param_fp32 must be contiguous" assert param_fp16.is_contiguous(), "param_fp16 must be contiguous" assert g_fp16.is_contiguous(), "g_fp16 must be contiguous" assert m_fp32.is_contiguous(), "m_fp32 must be contiguous" assert v_fp32.is_contiguous(), "v_fp32 must be contiguous" assert param_fp32.dtype == torch.float32, "param_fp32 must be float32 tensor" - assert param_fp16.dtype == torch.float16 or param_fp16.dtype == torch.bfloat16, "param_fp16 must be float16/bfloat16 tensor" - assert g_fp16.dtype == torch.float16 or g_fp16.dtype == torch.bfloat16, "g_fp16 must be float16/bfloat16 tensor" + assert ( + param_fp16.dtype == torch.float16 or param_fp16.dtype == torch.bfloat16 + ), 
"param_fp16 must be float16/bfloat16 tensor" + assert ( + g_fp16.dtype == torch.float16 or g_fp16.dtype == torch.bfloat16 + ), "g_fp16 must be float16/bfloat16 tensor" assert m_fp32.dtype == torch.float32, "m_fp32 must be float32 tensor" assert v_fp32.dtype == torch.float32, "v_fp32 must be float32 tensor" assert param_fp32.device == torch.device("cpu"), "param_fp32 must be a cpu tensor" @@ -30,17 +54,27 @@ def adam_cpu(param_fp32: torch.Tensor, param_fp16: torch.Tensor, delta_info: tor assert g_fp16.device == torch.device("cpu"), "g_fp16 must be a cpu tensor" assert m_fp32.device == torch.device("cpu"), "m_fp32 must be a cpu tensor" assert v_fp32.device == torch.device("cpu"), "v_fp32 must be a cpu tensor" - assert param_fp32.numel() == param_fp16.numel(), "param_fp32 and param_fp16 must have the same number of elements" - assert param_fp32.numel() == g_fp16.numel(), "param_fp32 and g_fp16 must have the same number of elements" - assert param_fp32.numel() == m_fp32.numel(), "param_fp32 and m_fp32 must have the same number of elements" - assert param_fp32.numel() == v_fp32.numel(), "param_fp32 and v_fp32 must have the same number of elements" + assert ( + param_fp32.numel() == param_fp16.numel() + ), "param_fp32 and param_fp16 must have the same number of elements" + assert ( + param_fp32.numel() == g_fp16.numel() + ), "param_fp32 and g_fp16 must have the same number of elements" + assert ( + param_fp32.numel() == m_fp32.numel() + ), "param_fp32 and m_fp32 must have the same number of elements" + assert ( + param_fp32.numel() == v_fp32.numel() + ), "param_fp32 and v_fp32 must have the same number of elements" if delta_info is not None: assert delta_info.is_contiguous(), "delta_info must be contiguous" assert delta_info.dtype == torch.float32, "delta_info must be float32 tensor" - assert delta_info.device == torch.device("cpu"), "delta_info must be a cpu tensor" + assert delta_info.device == torch.device( + "cpu" + ), "delta_info must be a cpu tensor" assert delta_info.numel() == 4, "delta_info have a length of 4" - bias_correction1 = 1 - beta1 ** step - bias_correction2 = 1 - beta2 ** step + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step if g_fp16.dtype == torch.float16: launcher = C.adam_cpu_fp16_launcher elif g_fp16.dtype == torch.bfloat16: @@ -55,17 +89,31 @@ def adam_cpu(param_fp32: torch.Tensor, param_fp16: torch.Tensor, delta_info: tor g_fp16.data_ptr(), m_fp32.data_ptr(), v_fp32.data_ptr(), - beta1, beta2, - eps, lr, + beta1, + beta2, + eps, + lr, scale, weight_decay, bias_correction1, bias_correction2, ) -def adam_fp16(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.Tensor, m_fp16: torch.Tensor, - v_fp32: torch.Tensor, beta1: float, beta2: float, eps: float, lr: float, scale: float, - weight_decay: float, step: int) -> None: + +def adam_fp16( + param_fp32: torch.Tensor, + param_fp16: torch.Tensor, + g_fp16: torch.Tensor, + m_fp16: torch.Tensor, + v_fp32: torch.Tensor, + beta1: float, + beta2: float, + eps: float, + lr: float, + scale: float, + weight_decay: float, + step: int, +) -> None: assert CHECK_INPUT(param_fp32), "param_fp32 must be contiguous and on cuda" assert CHECK_INPUT(param_fp16), "param_fp16 must be contiguous and on cuda" assert CHECK_INPUT(g_fp16), "g_fp16 must be contiguous and on cuda" @@ -76,12 +124,20 @@ def adam_fp16(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch. 
assert g_fp16.dtype == torch.float16, "g_fp16 must be float16 tensor" assert m_fp16.dtype == torch.float16, "m_fp16 must be float16 tensor" assert v_fp32.dtype == torch.float32, "v_fp32 must be float32 tensor" - assert param_fp32.numel() == param_fp16.numel(), "param_fp32 and param_fp16 must have the same number of elements" - assert param_fp32.numel() == g_fp16.numel(), "param_fp32 and g_fp16 must have the same number of elements" - assert param_fp32.numel() == m_fp16.numel(), "param_fp32 and m_fp32 must have the same number of elements" - assert param_fp32.numel() == v_fp32.numel(), "param_fp32 and v_fp32 must have the same number of elements" - bias_correction1 = 1 - beta1 ** step - bias_correction2 = 1 - beta2 ** step + assert ( + param_fp32.numel() == param_fp16.numel() + ), "param_fp32 and param_fp16 must have the same number of elements" + assert ( + param_fp32.numel() == g_fp16.numel() + ), "param_fp32 and g_fp16 must have the same number of elements" + assert ( + param_fp32.numel() == m_fp16.numel() + ), "param_fp32 and m_fp32 must have the same number of elements" + assert ( + param_fp32.numel() == v_fp32.numel() + ), "param_fp32 and v_fp32 must have the same number of elements" + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step stream = torch.cuda.current_stream().cuda_stream C.adam_fp16_launcher( param_fp32.numel(), @@ -90,18 +146,32 @@ def adam_fp16(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch. g_fp16.data_ptr(), m_fp16.data_ptr(), v_fp32.data_ptr(), - beta1, beta2, - eps, lr, + beta1, + beta2, + eps, + lr, scale, weight_decay, bias_correction1, bias_correction2, - stream + stream, ) - -def adam_bf16(param_fp32: torch.Tensor, param_bf16: torch.Tensor, g_bf16: torch.Tensor, m_fp32: torch.Tensor, - v_fp32: torch.Tensor, beta1: float, beta2: float, eps: float, lr: float, scale: float, - weight_decay: float, step: int) -> None: + + +def adam_bf16( + param_fp32: torch.Tensor, + param_bf16: torch.Tensor, + g_bf16: torch.Tensor, + m_fp32: torch.Tensor, + v_fp32: torch.Tensor, + beta1: float, + beta2: float, + eps: float, + lr: float, + scale: float, + weight_decay: float, + step: int, +) -> None: assert CHECK_INPUT(param_fp32), "param_fp32 must be contiguous and on cuda" assert CHECK_INPUT(param_bf16), "param_bf16 must be contiguous and on cuda" assert CHECK_INPUT(g_bf16), "g_bf16 must be contiguous and on cuda" @@ -112,12 +182,20 @@ def adam_bf16(param_fp32: torch.Tensor, param_bf16: torch.Tensor, g_bf16: torch. 
assert g_bf16.dtype == torch.bfloat16, "g_bf16 must be bfloat16 tensor" assert m_fp32.dtype == torch.float32, "m_fp32 must be bfloat16 tensor" assert v_fp32.dtype == torch.float32, "v_fp32 must be float32 tensor" - assert param_fp32.numel() == param_bf16.numel(), "param_fp32 and param_bf16 must have the same number of elements" - assert param_fp32.numel() == g_bf16.numel(), "param_fp32 and g_fp16 must have the same number of elements" - assert param_fp32.numel() == m_fp32.numel(), "param_fp32 and m_m_fp32 must have the same number of elements" - assert param_fp32.numel() == v_fp32.numel(), "param_fp32 and v_fp32 must have the same number of elements" - bias_correction1 = 1 - beta1 ** step - bias_correction2 = 1 - beta2 ** step + assert ( + param_fp32.numel() == param_bf16.numel() + ), "param_fp32 and param_bf16 must have the same number of elements" + assert ( + param_fp32.numel() == g_bf16.numel() + ), "param_fp32 and g_fp16 must have the same number of elements" + assert ( + param_fp32.numel() == m_fp32.numel() + ), "param_fp32 and m_m_fp32 must have the same number of elements" + assert ( + param_fp32.numel() == v_fp32.numel() + ), "param_fp32 and v_fp32 must have the same number of elements" + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step stream = torch.cuda.current_stream().cuda_stream if not C.is_bf16_supported(): raise NotImplementedError(f"bfloat16 is not supported on current GPU") @@ -128,11 +206,13 @@ def adam_bf16(param_fp32: torch.Tensor, param_bf16: torch.Tensor, g_bf16: torch. g_bf16.data_ptr(), m_fp32.data_ptr(), v_fp32.data_ptr(), - beta1, beta2, - eps, lr, + beta1, + beta2, + eps, + lr, scale, weight_decay, bias_correction1, bias_correction2, - stream - ) \ No newline at end of file + stream, + ) diff --git a/bmtrain/optim/adam.py b/bmtrain/optim/adam.py index d412b80e..f99c483c 100644 --- a/bmtrain/optim/adam.py +++ b/bmtrain/optim/adam.py @@ -2,7 +2,7 @@ from ..global_var import config from . import _function as F import torch.optim._functional -from .. import C +from .. import C from .. import nccl import inspect from ..utils import check_torch_version @@ -10,13 +10,23 @@ from itertools import chain from collections import defaultdict + class AdamOptimizer(torch.optim.Optimizer): """ - Adam optimizer + Adam optimizer support fp16 and bf16. 
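# [Editor's sketch, not part of this patch] The scale argument of step() shown below is
# the loss-scaling factor: fp32 gradients are divided by it before the update, and the
# fp16/bf16 kernels receive it for the same unscaling. Assumptions: the class is
# exported as bmt.optim.AdamOptimizer, bmt.nn.Linear is available, and the script runs
# under a distributed launcher so bmt.init_distributed() succeeds.
import torch
import bmtrain as bmt

bmt.init_distributed()
model = bmt.nn.Linear(16, 16, dtype=torch.half)
optimizer = bmt.optim.AdamOptimizer(model.parameters(), lr=1e-4)
loss_scale = 1024.0
loss = model(torch.randn(4, 16, dtype=torch.half, device="cuda")).sum()
(loss * loss_scale).backward()          # scale the fp16 loss to keep gradients representable
optimizer.step(scale=loss_scale)        # the update divides the gradients by the same scale
optimizer.zero_grad()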
""" + _bmtrain_optimizer = True - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, hold_steps=0): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + hold_steps=0, + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: @@ -36,13 +46,13 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0 def _on_justify_scale(self, old_scale, new_scale): delta = new_scale / old_scale for group in self.param_groups: - for p in group['params']: + for p in group["params"]: if p in self.state: state = self.state[p] if len(state) > 0: if p.dtype == torch.float16: - state['exp_avg'] *= delta - state['exp_avg_sq'] *= delta + state["exp_avg"] *= delta + state["exp_avg_sq"] *= delta @torch.no_grad() def step(self, closure=None, scale=1): @@ -60,84 +70,107 @@ def step(self, closure=None, scale=1): # update parameters for group in self.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is not None and p.requires_grad: if p.grad.is_sparse: - raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) if p.dtype not in [torch.float32, torch.half, torch.bfloat16]: - raise RuntimeError('Adam only supports fp32, fp16 and bf16 gradients') + raise RuntimeError( + "Adam only supports fp32, fp16 and bf16 gradients" + ) state = self.state[p] # Lazy state initialization if len(state) == 0: - state['step'] = 0 + state["step"] = 0 # Exponential moving average of gradient values if p.dtype == torch.float16: - state['exp_avg'] = torch.zeros(p.size(), dtype=torch.float16, device=p.device) # on device + state["exp_avg"] = torch.zeros( + p.size(), dtype=torch.float16, device=p.device + ) # on device else: - state['exp_avg'] = torch.zeros(p.size(), dtype=torch.float32, device=p.device) # on device + state["exp_avg"] = torch.zeros( + p.size(), dtype=torch.float32, device=p.device + ) # on device # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros(p.size(), dtype=torch.float32, device=p.device)# on device - + state["exp_avg_sq"] = torch.zeros( + p.size(), dtype=torch.float32, device=p.device + ) # on device + if p.dtype != torch.float32: - state['_param_fp32'] = torch.empty(p.size(), dtype=torch.float32, device=p.device) # on device - state['_param_fp32'].copy_(p) + state["_param_fp32"] = torch.empty( + p.size(), dtype=torch.float32, device=p.device + ) # on device + state["_param_fp32"].copy_(p) # update the steps for each param group update - if ('maximize' in group) and (group['maximize'] is True): + if ("maximize" in group) and (group["maximize"] is True): grad = -p.grad else: grad = p.grad - + if p.dtype == torch.float32: other_kwargs = {} - if 'maximize' in inspect.signature(torch.optim._functional.adam).parameters: - other_kwargs['maximize'] = False + if ( + "maximize" + in inspect.signature( + torch.optim._functional.adam + ).parameters + ): + other_kwargs["maximize"] = False torch.optim._functional.adam( [p], [grad / scale], - [state['exp_avg']], + [state["exp_avg"]], [state["exp_avg_sq"]], [], - [state["step"]] if check_torch_version("1.12.0") < 0 - else [torch.tensor(state["step"])], + ( + [state["step"]] + if check_torch_version("1.12.0") < 0 + else [torch.tensor(state["step"])] + ), amsgrad=False, - beta1=group['betas'][0], - beta2=group['betas'][1], - lr=0.0 if 
state["step"] < self._hold_steps else group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps'], + beta1=group["betas"][0], + beta2=group["betas"][1], + lr=0.0 if state["step"] < self._hold_steps else group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], **other_kwargs ) - state['step'] += 1 + state["step"] += 1 else: f = F.adam_fp16 if p.dtype == torch.float16 else F.adam_bf16 - state['step'] += 1 + state["step"] += 1 f( - state["_param_fp32"], # fp32 - p, # fp16 - grad, # fp16 - state['exp_avg'], # fp16: m - state["exp_avg_sq"], # fp32: v - group['betas'][0], group['betas'][1], - group['eps'], - 0.0 if state["step"] < self._hold_steps else group['lr'], + state["_param_fp32"], # fp32 + p, # fp16 + grad, # fp16 + state["exp_avg"], # fp16: m + state["exp_avg_sq"], # fp32: v + group["betas"][0], + group["betas"][1], + group["eps"], + 0.0 if state["step"] < self._hold_steps else group["lr"], scale, - group['weight_decay'], - state['step'] + group["weight_decay"], + state["step"], ) - - return loss def get_avg_delta(): - raise NotImplementedError("get delta info is not supported in Adam optimizer , try bmt.optim.AdamOffloadOptimizer") + raise NotImplementedError( + "get delta info is not supported in Adam optimizer , try bmt.optim.AdamOffloadOptimizer" + ) def get_var_delta(): - raise NotImplementedError("get delta info is not supported in Adam optimizer , try bmt.optim.AdamOffloadOptimizer") + raise NotImplementedError( + "get delta info is not supported in Adam optimizer , try bmt.optim.AdamOffloadOptimizer" + ) def load_state_dict(self, state_dict: dict) -> None: r"""Loads the optimizer state. @@ -150,35 +183,55 @@ def load_state_dict(self, state_dict: dict) -> None: state_dict = deepcopy(state_dict) # Validate the state_dict groups = self.param_groups - saved_groups = state_dict['param_groups'] + saved_groups = state_dict["param_groups"] if len(groups) != len(saved_groups): - raise ValueError("loaded state dict has a different number of " - "parameter groups") - param_lens = (len(g['params']) for g in groups) - saved_lens = (len(g['params']) for g in saved_groups) + raise ValueError( + "loaded state dict has a different number of " "parameter groups" + ) + param_lens = (len(g["params"]) for g in groups) + saved_lens = (len(g["params"]) for g in saved_groups) if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): - raise ValueError("loaded state dict contains a parameter group " - "that doesn't match the size of optimizer's group") + raise ValueError( + "loaded state dict contains a parameter group " + "that doesn't match the size of optimizer's group" + ) # Update the state - id_map = {old_id: p for old_id, p in - zip(chain.from_iterable((g['params'] for g in saved_groups)), - chain.from_iterable((g['params'] for g in groups)))} + id_map = { + old_id: p + for old_id, p in zip( + chain.from_iterable((g["params"] for g in saved_groups)), + chain.from_iterable((g["params"] for g in groups)), + ) + } # Copy state assigned to params (and cast tensors to appropriate types). # State that is not assigned to params is copied as is (needed for # backward compatibility). 
state = defaultdict(dict) - for k, v in state_dict['state'].items(): + for k, v in state_dict["state"].items(): if k in id_map: param = id_map[k] if param.dtype != torch.float32 and "_param_fp32" not in v: - v["_param_fp32"] = torch.empty(param.size(), dtype=torch.float32, device=param.device) + v["_param_fp32"] = torch.empty( + param.size(), dtype=torch.float32, device=param.device + ) v["_param_fp32"].copy_(param) - for name, dtype in [("exp_avg", torch.float16 if param.dtype == torch.float16 else torch.float32), ("exp_avg_sq", torch.float32), ("_param_fp32", torch.float32)]: + for name, dtype in [ + ( + "exp_avg", + ( + torch.float16 + if param.dtype == torch.float16 + else torch.float32 + ), + ), + ("exp_avg_sq", torch.float32), + ("_param_fp32", torch.float32), + ]: if name in v: v[name] = v[name].to(param.device).to(dtype) @@ -188,12 +241,12 @@ def load_state_dict(self, state_dict: dict) -> None: # Update parameter groups, setting their 'params' value def update_group(group, new_group): - new_group['params'] = group['params'] + new_group["params"] = group["params"] return new_group - param_groups = [ - update_group(g, ng) for g, ng in zip(groups, saved_groups)] - self.__setstate__({'state': state, 'param_groups': param_groups}) - #TODO zero_grad(set_to_none=True) makes optimizer crashed, maybe the reason of grad accu + param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] + self.__setstate__({"state": state, "param_groups": param_groups}) + + # TODO zero_grad(set_to_none=True) makes optimizer crashed, maybe the reason of grad accu def zero_grad(self, set_to_none: bool = False): - super().zero_grad(set_to_none=set_to_none) \ No newline at end of file + super().zero_grad(set_to_none=set_to_none) diff --git a/bmtrain/optim/adam_offload.py b/bmtrain/optim/adam_offload.py index c088a5ee..f6ea97ba 100644 --- a/bmtrain/optim/adam_offload.py +++ b/bmtrain/optim/adam_offload.py @@ -9,13 +9,24 @@ from collections import defaultdict from ._distributed import state_dict_gather + class AdamOffloadOptimizer(torch.optim.Optimizer): """ - Adam optimizer + Adam optimizer using optimizer offload. 
""" + _bmtrain_optimizer = True - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, hold_steps=0, record_delta=False): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + hold_steps=0, + record_delta=False, + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: @@ -35,8 +46,16 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0 self.record_delta = record_delta if self.record_delta: for group in self.param_groups: - for p in group['params']: - setattr(p, "_delta_info", ( torch.tensor([0 for i in range(4)], dtype=torch.float32, device="cpu") )) + for p in group["params"]: + setattr( + p, + "_delta_info", + ( + torch.tensor( + [0 for i in range(4)], dtype=torch.float32, device="cpu" + ) + ), + ) @torch.no_grad() def step(self, closure=None, scale=1): @@ -56,40 +75,69 @@ def step(self, closure=None, scale=1): update_params = [] for group in self.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is not None and p.requires_grad: if p.grad.is_sparse: - raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) if p.dtype not in [torch.float32, torch.float16, torch.bfloat16]: - raise RuntimeError('Adam only supports fp32, fp16 and bf16 gradients') + raise RuntimeError( + "Adam only supports fp32, fp16 and bf16 gradients" + ) state = self.state[p] # Lazy state initialization if len(state) == 0: - state['step'] = 0 + state["step"] = 0 # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros(p.size(), dtype=torch.float32, device="cpu") # on host + state["exp_avg"] = torch.zeros( + p.size(), dtype=torch.float32, device="cpu" + ) # on host # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros(p.size(), dtype=torch.float32, device="cpu") # on host + state["exp_avg_sq"] = torch.zeros( + p.size(), dtype=torch.float32, device="cpu" + ) # on host if p.dtype == torch.float32: - state['_param_fp32'] = torch.empty(p.size(), dtype=torch.float32, pin_memory=True) # on host - state['_param_fp32'].copy_(p) + state["_param_fp32"] = torch.empty( + p.size(), dtype=torch.float32, pin_memory=True + ) # on host + state["_param_fp32"].copy_(p) # placeholder - state["_grad_fp32"] = torch.empty(p.size(), dtype=torch.float32, pin_memory=True) # on host + state["_grad_fp32"] = torch.empty( + p.size(), dtype=torch.float32, pin_memory=True + ) # on host else: - state['_param_fp32'] = torch.empty(p.size(), dtype=torch.float32, device="cpu") # on host - state['_param_fp32'].copy_(p) + state["_param_fp32"] = torch.empty( + p.size(), dtype=torch.float32, device="cpu" + ) # on host + state["_param_fp32"].copy_(p) # placeholder - state["_param_fp16"] = torch.empty(p.size(), dtype=p.dtype, pin_memory=True) # on host - state["_grad_fp16"] = torch.empty(p.size(), dtype=p.dtype, pin_memory=True) # on host + state["_param_fp16"] = torch.empty( + p.size(), dtype=p.dtype, pin_memory=True + ) # on host + state["_grad_fp16"] = torch.empty( + p.size(), dtype=p.dtype, pin_memory=True + ) # on host if p not in self._events: self._events[p] = torch.cuda.Event() - update_params.append((p, state, self._events[p], group['betas'][0], group['betas'][1], group['eps'], group['lr'], group['weight_decay'])) + update_params.append( + ( + p, + state, + self._events[p], + 
group["betas"][0], + group["betas"][1], + group["eps"], + group["lr"], + group["weight_decay"], + ) + ) # transfer parameters to host asynchronously for param, state, event, _, _, _, _, _ in update_params: @@ -108,21 +156,27 @@ def step(self, closure=None, scale=1): # update parameters if param.dtype == torch.float32: state["_grad_fp32"].mul_(1.0 / scale) - if ('maximize' in group) and (group['maximize'] is True): + if ("maximize" in group) and (group["maximize"] is True): grad = -state["_grad_fp32"] else: grad = state["_grad_fp32"] other_kwargs = {} - if 'maximize' in inspect.signature(torch.optim._functional.adam).parameters: - other_kwargs['maximize'] = False + if ( + "maximize" + in inspect.signature(torch.optim._functional.adam).parameters + ): + other_kwargs["maximize"] = False torch.optim._functional.adam( [state["_param_fp32"]], [grad], [state["exp_avg"]], [state["exp_avg_sq"]], [], - [state["step"]] if check_torch_version("1.12.0") < 0 - else [torch.tensor(state["step"])], + ( + [state["step"]] + if check_torch_version("1.12.0") < 0 + else [torch.tensor(state["step"])] + ), amsgrad=False, beta1=beta1, beta2=beta2, @@ -136,7 +190,7 @@ def step(self, closure=None, scale=1): state["step"] += 1 else: state["step"] += 1 - if ('maximize' in group) and (group['maximize'] is True): + if ("maximize" in group) and (group["maximize"] is True): grad = -state["_grad_fp16"] else: grad = state["_grad_fp16"] @@ -147,22 +201,23 @@ def step(self, closure=None, scale=1): grad.view(-1), state["exp_avg"].view(-1), state["exp_avg_sq"].view(-1), - beta1, beta2, - eps, 0.0 if state["step"] < self._hold_steps else lr, + beta1, + beta2, + eps, + 0.0 if state["step"] < self._hold_steps else lr, scale, weight_decay, - state["step"] + state["step"], ) total_numel += state["_param_fp16"].numel() if self.record_delta: - sum_delta += param._delta_info[2].item(); - sum_sq_delta += param._delta_info[3].item(); + sum_delta += param._delta_info[2].item() + sum_sq_delta += param._delta_info[3].item() # transfer parameters back to device asynchronously param.copy_(state["_param_fp16"], non_blocking=True) if self.record_delta: self.avg_delta = sum_delta / total_numel - self.var_delta = sum_sq_delta / total_numel - self.avg_delta ** 2 - + self.var_delta = sum_sq_delta / total_numel - self.avg_delta**2 return loss @@ -180,76 +235,96 @@ def load_state_dict(self, state_dict: dict) -> None: from a call to :meth:`state_dict`. 
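# [Editor's sketch, not part of this patch] The gather / is_whole handling implemented in
# this method supports a checkpoint round trip like the one below. Assumptions: the class
# is exported as bmt.optim.AdamOffloadOptimizer, `model` has been built elsewhere, and a
# shared filesystem is visible to all ranks.
import torch
import bmtrain as bmt

optimizer = bmt.optim.AdamOffloadOptimizer(model.parameters(), lr=1e-4)
state = optimizer.state_dict(gather=True)   # gathers the sharded exp_avg / exp_avg_sq / _param_fp32 and marks is_whole=True
if bmt.rank() == 0:
    torch.save(state, "optimizer.pt")
optimizer.load_state_dict(torch.load("optimizer.pt"))  # re-slices each tensor to the local [_start_partition:_end_partition] range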
""" # deepcopy, to be consistent with module API - - - + state_dict = deepcopy(state_dict) # Validate the state_dict groups = self.param_groups - saved_groups = state_dict['param_groups'] + saved_groups = state_dict["param_groups"] if len(groups) != len(saved_groups): - raise ValueError("loaded state dict has a different number of " - "parameter groups") - param_lens = (len(g['params']) for g in groups) - saved_lens = (len(g['params']) for g in saved_groups) + raise ValueError( + "loaded state dict has a different number of " "parameter groups" + ) + param_lens = (len(g["params"]) for g in groups) + saved_lens = (len(g["params"]) for g in saved_groups) if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): - raise ValueError("loaded state dict contains a parameter group " - "that doesn't match the size of optimizer's group") + raise ValueError( + "loaded state dict contains a parameter group " + "that doesn't match the size of optimizer's group" + ) # Update the state - id_map = {old_id: p for old_id, p in - zip(chain.from_iterable((g['params'] for g in saved_groups)), - chain.from_iterable((g['params'] for g in groups)))} + id_map = { + old_id: p + for old_id, p in zip( + chain.from_iterable((g["params"] for g in saved_groups)), + chain.from_iterable((g["params"] for g in groups)), + ) + } - # _param_start_end = chain.from_iterable((g["params_start_end"] for g in saved_groups)) + # _param_start_end = chain.from_iterable((g["params_start_end"] for g in saved_groups)) # Copy state assigned to params (and cast tensors to appropriate types). # State that is not assigned to params is copied as is (needed for # backward compatibility). state = defaultdict(dict) - is_whole = False if "is_whole" not in state_dict else state_dict['is_whole'] + is_whole = False if "is_whole" not in state_dict else state_dict["is_whole"] pop_key = [] - for k, v in state_dict['state'].items(): + for k, v in state_dict["state"].items(): if k in id_map: param = id_map[k] if is_whole and param._start_partition is not None: - for key in ['_param_fp32', 'exp_avg_sq', 'exp_avg']: + for key in ["_param_fp32", "exp_avg_sq", "exp_avg"]: if key in v: - v[key] = v[key][param._start_partition:param._end_partition] + v[key] = v[key][ + param._start_partition : param._end_partition + ] elif is_whole and param._start_partition is None: pop_key.append(param) if "_param_fp32" not in v: with torch.no_grad(): - v["_param_fp32"] = torch.empty(param.size(), dtype=torch.float32, device="cpu") + v["_param_fp32"] = torch.empty( + param.size(), dtype=torch.float32, device="cpu" + ) v["_param_fp32"].copy_(param) - - for name, dtype in [("exp_avg", torch.float32), ("exp_avg_sq", torch.float32), ("_param_fp32", torch.float32)]: + + for name, dtype in [ + ("exp_avg", torch.float32), + ("exp_avg_sq", torch.float32), + ("_param_fp32", torch.float32), + ]: if name in v: v[name] = v[name].to("cpu").to(dtype) state[param] = v if param.dtype == torch.float32: - state[param]["_param_fp32"] = state[param]["_param_fp32"].pin_memory() # on host + state[param]["_param_fp32"] = state[param][ + "_param_fp32" + ].pin_memory() # on host # initialize placeholders - state[param]["_grad_fp32"] = torch.empty(param.size(), dtype=torch.float32, pin_memory=True) # on host + state[param]["_grad_fp32"] = torch.empty( + param.size(), dtype=torch.float32, pin_memory=True + ) # on host else: # initialize placeholders - state[param]["_param_fp16"] = torch.empty(param.size(), dtype=param.dtype, pin_memory=True) # on host - state[param]["_grad_fp16"] = 
torch.empty(param.size(), dtype=param.dtype, pin_memory=True) # on host + state[param]["_param_fp16"] = torch.empty( + param.size(), dtype=param.dtype, pin_memory=True + ) # on host + state[param]["_grad_fp16"] = torch.empty( + param.size(), dtype=param.dtype, pin_memory=True + ) # on host else: state[k] = v for k in pop_key: state.pop(k) + # Update parameter groups, setting their 'params' value def update_group(group, new_group): - new_group['params'] = group['params'] + new_group["params"] = group["params"] return new_group - param_groups = [ - update_group(g, ng) for g, ng in zip(groups, saved_groups)] - self.__setstate__({'state': state, 'param_groups': param_groups}) - - + + param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] + self.__setstate__({"state": state, "param_groups": param_groups}) def state_dict(self, gather=False) -> dict: r"""Returns the state of the optimizer as a :class:`dict`. @@ -261,20 +336,25 @@ def state_dict(self, gather=False) -> dict: * param_groups - a list containing all parameter groups where each parameter group is a dict """ - + # Save order indices instead of Tensors param_mappings = {} start_index = 0 def pack_group(group): nonlocal start_index - packed = {k: v for k, v in group.items() if k != 'params'} - param_mappings.update({id(p): i for i, p in enumerate(group['params'], start_index) - if id(p) not in param_mappings}) - packed['params'] = [param_mappings[id(p)] for p in group['params']] - start_index += len(packed['params']) + packed = {k: v for k, v in group.items() if k != "params"} + param_mappings.update( + { + id(p): i + for i, p in enumerate(group["params"], start_index) + if id(p) not in param_mappings + } + ) + packed["params"] = [param_mappings[id(p)] for p in group["params"]] + start_index += len(packed["params"]) return packed - + def cut_states(state): return { "step": state["step"], @@ -282,22 +362,25 @@ def cut_states(state): "exp_avg_sq": state["exp_avg_sq"], "_param_fp32": state["_param_fp32"], } + param_groups = [pack_group(g) for g in self.param_groups] # Remap state to use order indices as keys - packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): cut_states(v) - for k, v in self.state.items()} + packed_state = { + (param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): cut_states(v) + for k, v in self.state.items() + } states = { - 'state': packed_state, - 'param_groups': param_groups, + "state": packed_state, + "param_groups": param_groups, } if gather: states = state_dict_gather(states) - states['is_whole'] = True + states["is_whole"] = True else: - states['is_whole'] = False + states["is_whole"] = False return states - - #TODO zero_grad(set_to_none=True) makes optimizer crashed, maybe the reason of grad accu + + # TODO zero_grad(set_to_none=True) makes optimizer crashed, maybe the reason of grad accu def zero_grad(self, set_to_none: bool = False): super().zero_grad(set_to_none=set_to_none) diff --git a/bmtrain/param_init.py b/bmtrain/param_init.py index a46c7845..21f95f25 100644 --- a/bmtrain/param_init.py +++ b/bmtrain/param_init.py @@ -5,7 +5,13 @@ from .global_var import config -def init_distributed_parameter(params : Iterable[torch.nn.Parameter]): +def init_distributed_parameter(params: Iterable[torch.nn.Parameter]): + """Init param of params which is instance of DistributedParameter using param._init_method. + + Args: + params (Iterable[torch.nn.Parameter]): parameter tensors. 
+ + """ for param in params: if not isinstance(param, DistributedParameter): continue @@ -13,36 +19,50 @@ def init_distributed_parameter(params : Iterable[torch.nn.Parameter]): continue with torch.no_grad(): partition_size = param.storage().size() - global_size = partition_size * config['tp_zero_size'] * config['tp_size'] + global_size = partition_size * config["tp_zero_size"] * config["tp_size"] tmp_storage = param.storage_type()(global_size) tmp_tensor = torch.tensor([], dtype=param.dtype, device="cuda") tmp_tensor.set_(tmp_storage, 0, param._tp_original_shape) param._init_method(tmp_tensor) if param._tp_mode and param._tp_split_dim >= 0: - tensor_list = tmp_tensor.chunk(config['tp_size'], dim=param._tp_split_dim) - sub_tensor = tensor_list[config['topology'].tp_id].contiguous() - tmp_tensor = torch.empty(sub_tensor.shape, device=param.device, dtype=sub_tensor.dtype) + tensor_list = tmp_tensor.chunk( + config["tp_size"], dim=param._tp_split_dim + ) + sub_tensor = tensor_list[config["topology"].tp_id].contiguous() + tmp_tensor = torch.empty( + sub_tensor.shape, device=param.device, dtype=sub_tensor.dtype + ) tmp_tensor.copy_(sub_tensor) if param._tp_mode: - begin = config['tp_zero_rank'] + begin = config["tp_zero_rank"] else: - begin = config['zero_rank'] + begin = config["zero_rank"] end = begin + 1 # Pytorch 1.11 changed the API of storage.__getitem__ - torch.tensor([], dtype=param.dtype, device=param.device).set_(param.storage())[:] = \ - torch.tensor([], dtype=param.dtype, device=param.device).set_(tmp_tensor.storage())[partition_size * begin : partition_size * end] + torch.tensor([], dtype=param.dtype, device=param.device).set_( + param.storage() + )[:] = torch.tensor([], dtype=param.dtype, device=param.device).set_( + tmp_tensor.storage() + )[ + partition_size * begin : partition_size * end + ] # param.storage().copy_(tmp_storage[partition_size * config['rank'] : partition_size * (config['rank'] + 1)]) -def iterate_parameters(model : torch.nn.Module): + +def iterate_parameters(model: torch.nn.Module): + """ + Iterate over the parameters of the model. + """ for kw, val in model._parameters.items(): - if hasattr(val,"_in_block") and val._in_block: + if hasattr(val, "_in_block") and val._in_block: return [] yield val -def init_parameters(model : torch.nn.Module): + +def init_parameters(model: torch.nn.Module): """ Initialize the parameters of the model by calling the init_method of the distributed parameters. """ @@ -52,18 +72,21 @@ def init_parameters(model : torch.nn.Module): if isinstance(module, Block): module.init_parameters() else: - init_distributed_parameter( iterate_parameters(module) ) - + init_distributed_parameter(iterate_parameters(module)) + current_stream = torch.cuda.current_stream() - config['load_stream'].wait_stream(current_stream) + config["load_stream"].wait_stream(current_stream) -def grouped_parameters(model : torch.nn.Module) -> Generator[Tuple[str, List[torch.nn.Parameter]], None, None]: + +def grouped_parameters( + model: torch.nn.Module, +) -> Generator[Tuple[str, List[torch.nn.Parameter]], None, None]: """ Iterate over the parameters of the model grouped by the group name. This is similar to `torch.nn.Module.named_parameters()` .
""" - ret : List[torch.nn.Parameter] = {} + ret: List[torch.nn.Parameter] = {} for module in model.modules(): if isinstance(module, Block): for kw, params in module.grouped_parameters(): @@ -80,4 +103,3 @@ def grouped_parameters(model : torch.nn.Module) -> Generator[Tuple[str, List[tor ret[group].append(param) for kw, val in ret.items(): yield kw, val - diff --git a/bmtrain/parameter.py b/bmtrain/parameter.py index ffc27de2..2dad4a3d 100644 --- a/bmtrain/parameter.py +++ b/bmtrain/parameter.py @@ -5,6 +5,7 @@ from . import nccl from .distributed import all_gather + class DistributedParameter(torch.nn.Parameter): r""" DistributedParameter is a subclass of torch.nn.Parameter. @@ -20,41 +21,42 @@ class DistributedParameter(torch.nn.Parameter): **Note**: DistributedParameter must be on the CUDA device. It will transfer the data to device automatically when `__init__` called. """ - - _original_shape : torch.Size - _start_partition : int - _end_partition : int - _init_method : Optional[Callable[['DistributedParameter'], None]] + + _original_shape: torch.Size + _start_partition: int + _end_partition: int + _init_method: Optional[Callable[["DistributedParameter"], None]] _in_block: bool - _group : Optional[str] - - def __new__(cls, - data : torch.Tensor, - requires_grad : bool = True, - init_method : Optional[Callable[['DistributedParameter'], None]] = None, - group : Optional[str] = None, - tp_mode : bool = False, - tp_split_dim : int = -1, - ): + _group: Optional[str] + + def __new__( + cls, + data: torch.Tensor, + requires_grad: bool = True, + init_method: Optional[Callable[["DistributedParameter"], None]] = None, + group: Optional[str] = None, + tp_mode: bool = False, + tp_split_dim: int = -1, + ): if not config["initialized"]: raise RuntimeError("BMTrain is not initialized") num_of_elements = data.numel() - cuda_tensor = torch.tensor([], dtype=data.dtype, device="cuda") + cuda_tensor = torch.tensor([], dtype=data.dtype, device="cuda") if tp_mode: - comm = config['tp_zero_comm'] + comm = config["tp_zero_comm"] else: - comm = config['zero_comm'] + comm = config["zero_comm"] world_size = nccl.commCount(comm) rank = nccl.commRank(comm) cuda_storage_size = round_up(num_of_elements, world_size) // world_size original_shape = data.size() - tp_original_shape = original_shape + tp_original_shape = original_shape if tp_mode and tp_split_dim >= 0: tp_original_shape = list(original_shape) - tp_original_shape[tp_split_dim] *= config['tp_size'] + tp_original_shape[tp_split_dim] *= config["tp_size"] cuda_storage = cuda_tensor.storage_type()(cuda_storage_size) @@ -65,22 +67,22 @@ def __new__(cls, cuda_tensor_size = max(end_of_partition - start_of_partition, 0) cuda_tensor.set_(cuda_storage, 0, (cuda_tensor_size,)) - cuda_tensor.copy_(data.view(-1)[start_of_partition: end_of_partition]) + cuda_tensor.copy_(data.view(-1)[start_of_partition:end_of_partition]) ret = torch.Tensor._make_subclass(cls, cuda_tensor, requires_grad) - + setattr(ret, "_original_shape", original_shape) setattr(ret, "_start_partition", start_of_partition) setattr(ret, "_end_partition", end_of_partition) setattr(ret, "_init_method", init_method) setattr(ret, "_in_block", False) setattr(ret, "_group", group if not tp_mode else "tp") - + setattr(ret, "_tp_mode", tp_mode) setattr(ret, "_zero_comm", comm) setattr(ret, "_tp_split_dim", tp_split_dim) setattr(ret, "_tp_original_shape", tp_original_shape) return ret - + @property def group(self): """The group name of the distributed parameter.""" @@ -88,52 +90,70 @@ def group(self): return 
self._group def gather(self) -> torch.Tensor: - """Gather the data from all the distributed nodes. + """Gather the data from ZeRO distributed nodes. Return: torch.Tensor: The gathered data. - + """ - with torch.cuda.stream(config['load_stream']): + with torch.cuda.stream(config["load_stream"]): output_tensor = OpAllGather.apply(self) current_stream = torch.cuda.current_stream() - output_tensor.record_stream( current_stream ) - current_stream.wait_stream(config['load_stream']) + output_tensor.record_stream(current_stream) + current_stream.wait_stream(config["load_stream"]) return output_tensor def gather_all(self) -> torch.tensor: + """Gather the data from ZeRO and Tensor Parallel distributed nodes. + + Return: + torch.Tensor: The gathered data. + + """ zero_param = self.gather() - if config['tp_size'] > 1 and self._tp_split_dim >= 0: - output_tensor = all_gather(zero_param, config['tp_comm']) + if config["tp_size"] > 1 and self._tp_split_dim >= 0: + output_tensor = all_gather(zero_param, config["tp_comm"]) if self._tp_split_dim == 1: - output_list = output_tensor.chunk(config['tp_size'], dim=0) - output = torch.cat(output_list, dim=output_list[0].dim()-1).flatten(0,1) + output_list = output_tensor.chunk(config["tp_size"], dim=0) + output = torch.cat(output_list, dim=output_list[0].dim() - 1).flatten( + 0, 1 + ) return output else: - return output_tensor.flatten(0,1) + return output_tensor.flatten(0, 1) else: return zero_param def tp_gather(self) -> torch.tensor: - if config['tp_size'] > 1 and self._tp_split_dim >= 0: - output_tensor = all_gather(self, config['tp_comm']) + """Gather the data from Tensor Parallel distributed nodes. + + Return: + torch.Tensor: The gathered data. + + """ + if config["tp_size"] > 1 and self._tp_split_dim >= 0: + output_tensor = all_gather(self, config["tp_comm"]) if self._tp_split_dim == 1: - output_list = output_tensor.chunk(config['tp_size'], dim=0) - output = torch.cat(output_list, dim=output_list[0].dim()-1).flatten(0,1) + output_list = output_tensor.chunk(config["tp_size"], dim=0) + output = torch.cat(output_list, dim=output_list[0].dim() - 1).flatten( + 0, 1 + ) return output - else: - return output_tensor.flatten(0,1) + else: + return output_tensor.flatten(0, 1) else: return self - def _copy_data(self, data : torch.Tensor): + def _copy_data(self, data: torch.Tensor): + """Copy data to self.data.""" self.data.copy_(data.view(-1)[self._start_partition : self._end_partition]) - + + class OpAllGather(torch.autograd.Function): @staticmethod - def forward(ctx, value : DistributedParameter): + def forward(ctx, value: DistributedParameter): assert isinstance(value, DistributedParameter) - comm = value._zero_comm #config['zero_comm'] + comm = value._zero_comm # config['zero_comm'] world_size = nccl.commCount(comm) ctx.comm = comm ctx.world_size = world_size @@ -142,22 +162,18 @@ def forward(ctx, value : DistributedParameter): global_size = partition_size * world_size storage = value.storage_type()(global_size) - - nccl.allGather( - value.storage(), - storage, - comm - ) + + nccl.allGather(value.storage(), storage, comm) output_tensor = torch.tensor([], dtype=value.dtype, device="cuda") output_tensor.set_(storage, 0, value._original_shape) - + ctx.partition_size = partition_size ctx.tensor_size = value.size(0) return output_tensor - + @staticmethod - def backward(ctx, grad_output : torch.Tensor): + def backward(ctx, grad_output: torch.Tensor): if not grad_output.is_contiguous(): grad_output = grad_output.contiguous() @@ -167,16 +183,12 @@ def backward(ctx, 
grad_output : torch.Tensor): pass else: grad_output_storage.resize_(ctx.partition_size * ctx.world_size) - nccl.reduceScatter( - grad_output_storage, - grad_storage, - 'sum', - ctx.comm - ) + nccl.reduceScatter(grad_output_storage, grad_storage, "sum", ctx.comm) grad_tensor = torch.tensor([], dtype=grad_output.dtype, device="cuda") grad_tensor.set_(grad_storage, 0, (ctx.tensor_size,)) return grad_tensor + class ParameterInitializer: """ ParameterInitializer is a helper class that is used to initialize the distributed parameters. @@ -184,10 +196,11 @@ class ParameterInitializer: Similar to functools.partial . """ - def __init__(self, func : Callable, *args, **kwargs) -> None: + + def __init__(self, func: Callable, *args, **kwargs) -> None: self.func = func self._args = args self._kwargs = kwargs - - def __call__(self, param : DistributedParameter): + + def __call__(self, param: DistributedParameter): self.func(param, *self._args, **self._kwargs) diff --git a/bmtrain/synchronize.py b/bmtrain/synchronize.py index 8dd5fa67..87619159 100644 --- a/bmtrain/synchronize.py +++ b/bmtrain/synchronize.py @@ -4,6 +4,7 @@ import warnings from typing import Optional + def synchronize(): """ Synchronize all the workers across all nodes. (both CPU and GPU are synchronized) @@ -11,48 +12,61 @@ def synchronize(): if not config["initialized"]: raise RuntimeError("BMTrain is not initialized") - with torch.cuda.stream(config['barrier_stream']): + with torch.cuda.stream(config["barrier_stream"]): barrier = torch.cuda.FloatTensor([1]) - nccl.allReduce(barrier.storage(), barrier.storage(), 'sum', config['comm']) - config['barrier_stream'].synchronize() + nccl.allReduce(barrier.storage(), barrier.storage(), "sum", config["comm"]) + config["barrier_stream"].synchronize() + def wait_loader(): + """ + Make the calc stream (normally the current stream) wait for the latest loader event, and record a new one. + """ if not config["initialized"]: raise RuntimeError("BMTrain is not initialized") - # wait lastest loader event, and set a new one - config['load_event'].synchronize() - config['calc_stream'].record_event(config['load_event']) + config["load_event"].synchronize() + config["calc_stream"].record_event(config["load_event"]) -def sum_loss(loss : torch.Tensor, comm: Optional[nccl.NCCLCommunicator] = None): +def sum_loss(loss: torch.Tensor, comm: Optional[nccl.NCCLCommunicator] = None): """ Sum the loss across all workers. This is a helper function to reduce the loss across all workers. """ if comm is None: - comm = config['comm'] - warnings.warn("bmtrain.sum_loss is deprecated and will be removed in later version. Use bmtrain.distributed.all_reduce instead.", DeprecationWarning) + comm = config["comm"] + warnings.warn( + "bmtrain.sum_loss is deprecated and will be removed in later version. Use bmtrain.distributed.all_reduce instead.", + DeprecationWarning, + ) + + return distributed.all_reduce(loss, "avg", comm) - return distributed.all_reduce(loss, "avg", comm) def gather_result(result: torch.Tensor): - warnings.warn("bmtrain.gather_result is deprecated and will be removed in later version. Use bmtrain.distributed.all_gather instead.", DeprecationWarning) + """ + Gather the result across all workers. + """ + warnings.warn( + "bmtrain.gather_result is deprecated and will be removed in later version.
Use bmtrain.distributed.all_gather instead.", + DeprecationWarning, + ) if result.storage_offset() != 0 or result.storage().size() != result.numel(): # Create a clone of the original tensor if it's a slice result = result.clone() - + output_cuda = True if not result.is_cuda: result = result.cuda() output_cuda = False - ret = torch.empty((result.shape[0]*config['world_size'], *list(result.shape[1:])), device=result.device, dtype=result.dtype) - nccl.allGather( - result.storage(), - ret.storage(), - config['comm'] + ret = torch.empty( + (result.shape[0] * config["world_size"], *list(result.shape[1:])), + device=result.device, + dtype=result.dtype, ) + nccl.allGather(result.storage(), ret.storage(), config["comm"]) if output_cuda: return ret else: diff --git a/bmtrain/utils.py b/bmtrain/utils.py index 8cb87808..daa4c595 100644 --- a/bmtrain/utils.py +++ b/bmtrain/utils.py @@ -7,6 +7,8 @@ ALIGN = 4 ROW_WIDTH = 60 + + def check_torch_version(version_str): """ Checks if the current torch version is greater than or equal to the given version. @@ -14,13 +16,23 @@ def check_torch_version(version_str): """ version_int_arr = [int(v) for v in version_str.split(".")] - version_int = version_int_arr[0] * 10000 + version_int_arr[1] * 100 + version_int_arr[2] + version_int = ( + version_int_arr[0] * 10000 + version_int_arr[1] * 100 + version_int_arr[2] + ) torch_version = torch.__version__.split("+")[0] current_version_int_arr = [int(v) for v in torch_version.split(".")] - current_version_int = current_version_int_arr[0] * 10000 + current_version_int_arr[1] * 100 + current_version_int_arr[2] + current_version_int = ( + current_version_int_arr[0] * 10000 + + current_version_int_arr[1] * 100 + + current_version_int_arr[2] + ) return current_version_int - version_int + def load_nccl_pypi(): + """ + Check if the current NCCL is available. + """ try: import nvidia.nccl except: @@ -28,16 +40,23 @@ def load_nccl_pypi(): path = os.path.join(os.path.dirname(nvidia.nccl.__file__), "lib") for file_so in os.listdir(path): - file_split = file_so.split('.') - if file_split[-1] == "so" or (len(file_split)>1 and file_split[-2] == "so"): + file_split = file_so.split(".") + if file_split[-1] == "so" or (len(file_split) > 1 and file_split[-2] == "so"): ctypes.CDLL(os.path.join(path, file_so)) - - + + def round_up(x, d): + """ + Return (x + d - 1) // d * d + """ return (x + d - 1) // d * d -def print_dict(title : str, content : Dict[str, Any], file=sys.stdout): - max_kw_len = max([ len(kw) for kw in content.keys() ]) + +def print_dict(title: str, content: Dict[str, Any], file=sys.stdout): + """ + Print a dict to the given file. + """ + max_kw_len = max([len(kw) for kw in content.keys()]) max_kw_len = round_up(max_kw_len + 3, 4) raw_content = "" @@ -45,7 +64,7 @@ def print_dict(title : str, content : Dict[str, Any], file=sys.stdout): for kw, val in content.items(): raw_content += kw + " :" + " " * (max_kw_len - len(kw) - 2) raw_val = "%s" % val - + len_val_row = ROW_WIDTH - max_kw_len st = 0 if len(raw_val) == 0: @@ -53,20 +72,24 @@ def print_dict(title : str, content : Dict[str, Any], file=sys.stdout): while st < len(raw_val): if st > 0: raw_content += " " * max_kw_len - raw_content += raw_val[st: st + len_val_row] + "\n" + raw_content += raw_val[st : st + len_val_row] + "\n" st += len_val_row - + print_block(title, raw_content, file) -def print_block(title : str, content : Optional[str] = None, file=sys.stdout): +def print_block(title: str, content: Optional[str] = None, file=sys.stdout): + """ + Print a titled block of content to the given file.
+ """ left_title = (ROW_WIDTH - len(title) - 2) // 2 right_title = ROW_WIDTH - len(title) - 2 - left_title - + print("=" * left_title + " " + title + " " + "=" * right_title, file=file) if content is not None: print(content, file=file) - + + def print_rank(*args, rank=0, **kwargs): """ Prints the message only on the `rank` of the process. @@ -80,6 +103,7 @@ def print_rank(*args, rank=0, **kwargs): if config["rank"] == rank: print(*args, **kwargs) + def see_memory(message, detail=False): """ Outputs a message followed by GPU memory status summary on rank 0. @@ -93,43 +117,58 @@ def see_memory(message, detail=False): >>> bmt.see_memory("before forward") >>> # forward_step() >>> bmt.see_memory("after forward") - + """ print_rank(message) if detail: print_rank(torch.cuda.memory_summary()) else: - print_rank(f""" + print_rank( + f""" ======================================================================================= memory_allocated {round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024),2 )} GB max_memory_allocated {round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024),2)} GB ======================================================================================= - """) + """ + ) torch.cuda.reset_peak_memory_stats() + def tp_split_tensor(tensor, split_dim): - tensor_list = tensor.chunk(config['tp_size'], dim=split_dim) - sub_tensor = tensor_list[config['topology'].tp_id].contiguous() - tmp_tensor = torch.empty(sub_tensor.shape, device=sub_tensor.device, dtype=sub_tensor.dtype) + """ + Outputs the chunk of the tensor owned by config["topology"].tp_id after splitting along split_dim. + + Args: + tensor (torch.tensor): The tensor to be split. + split_dim (int): The dimension along which to split the input tensor. + + """ + tensor_list = tensor.chunk(config["tp_size"], dim=split_dim) + sub_tensor = tensor_list[config["topology"].tp_id].contiguous() + tmp_tensor = torch.empty( + sub_tensor.shape, device=sub_tensor.device, dtype=sub_tensor.dtype + ) + tmp_tensor.copy_(sub_tensor) return tmp_tensor + class AverageRecorder: """A utility class to record the average value of a quantity over time. Args: alpha (float): The decay factor of the average. start_value (float): The initial value of the average. - + Use `.value` to get the current average value. It is calculated as `alpha * old_value + (1 - alpha) * new_value`. - + """ - def __init__(self, alpha = 0.9, start_value = 0): + + def __init__(self, alpha=0.9, start_value=0): self._value = start_value self.alpha = alpha self._steps = 0 - + def record(self, v): """Records a new value.
Args: @@ -137,7 +176,7 @@ def record(self, v): """ self._value = self._value * self.alpha + v * (1 - self.alpha) self._steps += 1 - + @property def value(self): if self._steps <= 0: diff --git a/bmtrain/wrapper.py b/bmtrain/wrapper.py index 722a8037..93d9877a 100644 --- a/bmtrain/wrapper.py +++ b/bmtrain/wrapper.py @@ -2,16 +2,20 @@ from .block_layer import Block, TransformerBlockList from .layer import DistributedModule, DistributedParameter -def make_distributed(model : torch.nn.Module): + +def make_distributed(model: torch.nn.Module): for kw in list(model._parameters.keys()): if model._parameters[kw] is not None: if not isinstance(model._parameters[kw], DistributedParameter): - model._parameters[kw] = DistributedParameter(model._parameters[kw], requires_grad=model._parameters[kw].requires_grad) - + model._parameters[kw] = DistributedParameter( + model._parameters[kw], + requires_grad=model._parameters[kw].requires_grad, + ) + for kw in list(model._buffers.keys()): if model._buffers[kw] is not None: model._buffers[kw] = model._buffers[kw].cuda() - + for kw in list(model._modules.keys()): if isinstance(model, torch.nn.ModuleList): if not isinstance(model._modules[kw], Block): @@ -19,10 +23,15 @@ def make_distributed(model : torch.nn.Module): else: model._modules[kw] = model_wrapper_dispatch(model._modules[kw]) - model.__class__ = type("bmtrain.Distributed" + model.__class__.__name__, (model.__class__, DistributedModule), {}) + model.__class__ = type( + "bmtrain.Distributed" + model.__class__.__name__, + (model.__class__, DistributedModule), + {}, + ) return model -def model_wrapper_dispatch(model : torch.nn.Module): + +def model_wrapper_dispatch(model: torch.nn.Module): if isinstance(model, TransformerBlockList): return model elif isinstance(model, DistributedModule): @@ -32,7 +41,8 @@ def model_wrapper_dispatch(model : torch.nn.Module): else: return make_distributed(model) -def BMTrainModelWrapper(model : torch.nn.Module) -> torch.nn.Module: + +def BMTrainModelWrapper(model: torch.nn.Module) -> torch.nn.Module: """ Automatically wrap a model in a BMTrain model. Replaces all parameters with DistributedParameter, all modules with DistributedModule, and modules in ModuleList with Block. diff --git a/bmtrain/zero_context.py b/bmtrain/zero_context.py index 653f40fa..8a74b3f8 100644 --- a/bmtrain/zero_context.py +++ b/bmtrain/zero_context.py @@ -3,8 +3,19 @@ from .global_var import config from .synchronize import wait_loader + class ZeroContext: - def __init__(self, block : 'Block', ctx_dict : dict = None, pipe = False) -> None: + """ZeroContext is a helper class that gathers parameters before the module forward pass and reduce-scatters + gradients after the module backward pass. + + Args: + block (Block): The input Block. + ctx_dict (dict): block._layer_dict. + pipe (bool): True if pipeline parallelism is used. + + """ + + def __init__(self, block: "Block", ctx_dict: dict = None, pipe=False) -> None: self.block = block self.ctx_dict = ctx_dict self._param_buffer = {} @@ -15,7 +26,7 @@ def __init__(self, block : 'Block', ctx_dict : dict = None, pipe = False) -> Non def enter(self, flag=0, requires_grad=False): """ - gather parameters + Gather parameters before the module forward pass and initialize the gradient buffer before backward.
""" if self.block._ready: return @@ -28,29 +39,45 @@ def enter(self, flag=0, requires_grad=False): assert self.block._storage_params[kw].is_cuda assert kw not in self._grad_buffer assert kw not in self._param_buffer - local_param = self.block._storage_params[kw] - + local_param = self.block._storage_params[kw] + storage_type = local_param.storage_type() if flag != 2: - self._param_buffer[kw] = storage_type(val["partition_size"] * val["world_size"]) - self._param_tensor[kw] = torch.tensor([], dtype=self._param_buffer[kw].dtype, device=self._param_buffer[kw].device).set_(self._param_buffer[kw]) + self._param_buffer[kw] = storage_type( + val["partition_size"] * val["world_size"] + ) + self._param_tensor[kw] = torch.tensor( + [], + dtype=self._param_buffer[kw].dtype, + device=self._param_buffer[kw].device, + ).set_(self._param_buffer[kw]) if requires_grad and local_param.requires_grad: - self._grad_buffer[kw] = storage_type(val["partition_size"] * val["world_size"]) - self._grad_tensor[kw] = torch.tensor([], dtype=self._grad_buffer[kw].dtype, device=self._grad_buffer[kw].device).set_(self._grad_buffer[kw]).zero_() + self._grad_buffer[kw] = storage_type( + val["partition_size"] * val["world_size"] + ) + self._grad_tensor[kw] = ( + torch.tensor( + [], + dtype=self._grad_buffer[kw].dtype, + device=self._grad_buffer[kw].device, + ) + .set_(self._grad_buffer[kw]) + .zero_() + ) if flag != 2: nccl.groupStart() for kw, val in self.block._storage_info.items(): nccl.allGather( self.block._storage_params[kw].storage(), self._param_buffer[kw], - val['zero_comm'] + val["zero_comm"], ) nccl.groupEnd() current_stream = torch.cuda.current_stream() current_stream.wait_stream(config["load_stream"]) - + # set wait stream for each storage for kw in self.block._storage_info.keys(): if flag != 2: @@ -67,23 +94,32 @@ def enter(self, flag=0, requires_grad=False): if flag != 2: dtype = self._param_buffer[kw_name].dtype device = self._param_buffer[kw_name].device - param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self._param_buffer[kw_name], offset, shape) + param["parameter"].data = torch.tensor( + [], dtype=dtype, device=device + ).set_(self._param_buffer[kw_name], offset, shape) else: dtype = param["parameter"].data.dtype device = param["parameter"].data.device - param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.ctx_dict[kw_name], offset, shape) - - if requires_grad and kw_name in self._grad_buffer and param["parameter"].requires_grad: - param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self._grad_buffer[kw_name], offset, shape) + param["parameter"].data = torch.tensor( + [], dtype=dtype, device=device + ).set_(self.ctx_dict[kw_name], offset, shape) + + if ( + requires_grad + and kw_name in self._grad_buffer + and param["parameter"].requires_grad + ): + param["parameter"].grad = torch.tensor( + [], dtype=dtype, device=device + ).set_(self._grad_buffer[kw_name], offset, shape) def __enter__(self): self.enter() - + def exit(self, flag=0, backward=False): """ - Reduce scatter gradients + Reduce scatter gradients when backward and release all parameters from buffer to block_storge when forward is done. 
""" - if not self._need_release: return self._need_release = False @@ -95,13 +131,23 @@ def exit(self, flag=0, backward=False): # accumulate previous gradient if local_param.requires_grad: if local_param.grad is None: - grad_storage = val["storage_type"](val["partition_size"]) # initialize gradient if not exist - local_param.grad = torch.tensor([], dtype=grad_storage.dtype, device=grad_storage.device).set_(grad_storage).zero_() + grad_storage = val["storage_type"]( + val["partition_size"] + ) # initialize gradient if not exist + local_param.grad = ( + torch.tensor( + [], dtype=grad_storage.dtype, device=grad_storage.device + ) + .set_(grad_storage) + .zero_() + ) else: - self._grad_tensor[kw][val["begin"]:val["end"]] += local_param.grad - + self._grad_tensor[kw][ + val["begin"] : val["end"] + ] += local_param.grad + current_stream = torch.cuda.current_stream() - config["load_stream"].wait_stream(current_stream) # wait for backward + config["load_stream"].wait_stream(current_stream) # wait for backward with torch.cuda.stream(config["load_stream"]): nccl.groupStart() @@ -114,7 +160,7 @@ def exit(self, flag=0, backward=False): self._grad_buffer[kw], local_param.grad.storage(), "sum", - val['zero_comm'] + val["zero_comm"], ) nccl.groupEnd() @@ -123,7 +169,6 @@ def exit(self, flag=0, backward=False): # grads can not be freed until reduce ops finish self._grad_tensor[kw].record_stream(config["load_stream"]) - # Release all parameters from buffer to block_storge for param in self.block._param_info: kw_name = param["kw_name"] @@ -135,9 +180,16 @@ def exit(self, flag=0, backward=False): continue begin = param["begin"] end = param["end"] - param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].storage(), begin, end) - if param["parameter"].requires_grad and self.block._storage_params[kw_name].grad is not None: - param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].grad.storage(), begin, end) + param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_( + self.block._storage_params[kw_name].storage(), begin, end + ) + if ( + param["parameter"].requires_grad + and self.block._storage_params[kw_name].grad is not None + ): + param["parameter"].grad = torch.tensor( + [], dtype=dtype, device=device + ).set_(self.block._storage_params[kw_name].grad.storage(), begin, end) if flag == 1: for i in self._param_buffer: self.ctx_dict[i] = self._param_buffer[i] @@ -145,7 +197,7 @@ def exit(self, flag=0, backward=False): self._param_tensor = {} self._grad_buffer = {} self._param_buffer = {} - + def __exit__(self, exc_type, exc_val, exc_tb): # reduce scatter gradients self.exit() diff --git a/docs/Makefile b/docs/Makefile index d0c3cbf1..4f2fbe66 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -5,7 +5,7 @@ # from the environment for the first two. SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build -SOURCEDIR = source +SOURCEDIR = source-en BUILDDIR = build # Put it first so that "make" without argument is like "make help". diff --git a/docs/source-en/api/bmtrain.benchmark.rst_bk b/docs/source-en/api/bmtrain.benchmark.rst_bk new file mode 100644 index 00000000..f8b2902d --- /dev/null +++ b/docs/source-en/api/bmtrain.benchmark.rst_bk @@ -0,0 +1,53 @@ +bmtrain.benchmark package +========================= + +Submodules +---------- + +bmtrain.benchmark.all\_gather module +------------------------------------ + +.. 
automodule:: bmtrain.benchmark.all_gather + :members: + :undoc-members: + :show-inheritance: + +bmtrain.benchmark.reduce\_scatter module +---------------------------------------- + +.. automodule:: bmtrain.benchmark.reduce_scatter + :members: + :undoc-members: + :show-inheritance: + +bmtrain.benchmark.send\_recv module +----------------------------------- + +.. automodule:: bmtrain.benchmark.send_recv + :members: + :undoc-members: + :show-inheritance: + +bmtrain.benchmark.shape module +------------------------------ + +.. automodule:: bmtrain.benchmark.shape + :members: + :undoc-members: + :show-inheritance: + +bmtrain.benchmark.utils module +------------------------------ + +.. automodule:: bmtrain.benchmark.utils + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.benchmark + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source-en/api/bmtrain.distributed.rst_bk b/docs/source-en/api/bmtrain.distributed.rst_bk new file mode 100644 index 00000000..ef41db07 --- /dev/null +++ b/docs/source-en/api/bmtrain.distributed.rst_bk @@ -0,0 +1,21 @@ +bmtrain.distributed package +=========================== + +Submodules +---------- + +bmtrain.distributed.ops module +------------------------------ + +.. automodule:: bmtrain.distributed.ops + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.distributed + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source-en/api/bmtrain.inspect.rst b/docs/source-en/api/bmtrain.inspect.rst new file mode 100644 index 00000000..c57195ad --- /dev/null +++ b/docs/source-en/api/bmtrain.inspect.rst @@ -0,0 +1,37 @@ +bmtrain.inspect package +======================= + +Submodules +---------- + +bmtrain.inspect.format module +----------------------------- + +.. automodule:: bmtrain.inspect.format + :members: format_summary + :undoc-members: + :show-inheritance: + +bmtrain.inspect.model module +---------------------------- + +.. automodule:: bmtrain.inspect.model + :members: inspect_model + :undoc-members: + :show-inheritance: + +bmtrain.inspect.tensor module +----------------------------- + +.. automodule:: bmtrain.inspect.tensor + :members: inspect_tensor, InspectTensor + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.inspect + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source-en/api/bmtrain.loss.rst b/docs/source-en/api/bmtrain.loss.rst new file mode 100644 index 00000000..03b65646 --- /dev/null +++ b/docs/source-en/api/bmtrain.loss.rst @@ -0,0 +1,21 @@ +bmtrain.loss package +==================== + +Submodules +---------- + +bmtrain.loss.cross\_entropy module +---------------------------------- + +.. automodule:: bmtrain.loss.cross_entropy + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.loss + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source-en/api/bmtrain.lr_scheduler.rst b/docs/source-en/api/bmtrain.lr_scheduler.rst new file mode 100644 index 00000000..0ba033af --- /dev/null +++ b/docs/source-en/api/bmtrain.lr_scheduler.rst @@ -0,0 +1,61 @@ +bmtrain.lr\_scheduler package +============================= + +Submodules +---------- + +bmtrain.lr\_scheduler.cosine module +----------------------------------- + +.. 
automodule:: bmtrain.lr_scheduler.cosine + :members: + :undoc-members: + :show-inheritance: + +bmtrain.lr\_scheduler.exponential module +---------------------------------------- + +.. automodule:: bmtrain.lr_scheduler.exponential + :members: + :undoc-members: + :show-inheritance: + +bmtrain.lr\_scheduler.linear module +----------------------------------- + +.. automodule:: bmtrain.lr_scheduler.linear + :members: + :undoc-members: + :show-inheritance: + +bmtrain.lr\_scheduler.no\_decay module +-------------------------------------- + +.. automodule:: bmtrain.lr_scheduler.no_decay + :members: + :undoc-members: + :show-inheritance: + +bmtrain.lr\_scheduler.noam module +--------------------------------- + +.. automodule:: bmtrain.lr_scheduler.noam + :members: + :undoc-members: + :show-inheritance: + +bmtrain.lr\_scheduler.warmup module +----------------------------------- + +.. automodule:: bmtrain.lr_scheduler.warmup + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.lr_scheduler + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source-en/api/bmtrain.nccl.rst_bk b/docs/source-en/api/bmtrain.nccl.rst_bk new file mode 100644 index 00000000..3755d9ef --- /dev/null +++ b/docs/source-en/api/bmtrain.nccl.rst_bk @@ -0,0 +1,21 @@ +bmtrain.nccl package +==================== + +Submodules +---------- + +bmtrain.nccl.enums module +------------------------- + +.. automodule:: bmtrain.nccl.enums + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.nccl + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source-en/api/bmtrain.nn.rst b/docs/source-en/api/bmtrain.nn.rst new file mode 100644 index 00000000..8e2a531f --- /dev/null +++ b/docs/source-en/api/bmtrain.nn.rst @@ -0,0 +1,53 @@ +bmtrain.nn package +================== + +Submodules +---------- + +bmtrain.nn.column\_parallel\_linear module +------------------------------------------ + +.. automodule:: bmtrain.nn.column_parallel_linear + :members: + :undoc-members: + :show-inheritance: + +bmtrain.nn.linear module +------------------------ + +.. automodule:: bmtrain.nn.linear + :members: + :undoc-members: + :show-inheritance: + +bmtrain.nn.parallel\_embedding module +------------------------------------- + +.. automodule:: bmtrain.nn.parallel_embedding + :members: + :undoc-members: + :show-inheritance: + +bmtrain.nn.parallel\_linear\_func module +---------------------------------------- + +.. automodule:: bmtrain.nn.parallel_linear_func + :members: + :undoc-members: + :show-inheritance: + +bmtrain.nn.row\_parallel\_linear module +--------------------------------------- + +.. automodule:: bmtrain.nn.row_parallel_linear + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.nn + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source-en/api/bmtrain.optim.rst b/docs/source-en/api/bmtrain.optim.rst new file mode 100644 index 00000000..2d47a3dd --- /dev/null +++ b/docs/source-en/api/bmtrain.optim.rst @@ -0,0 +1,37 @@ +bmtrain.optim package +===================== + +Submodules +---------- + +bmtrain.optim.adam module +------------------------- + +.. automodule:: bmtrain.optim.adam + :members: + :undoc-members: + :show-inheritance: + +bmtrain.optim.adam\_offload module +---------------------------------- + +.. 
automodule:: bmtrain.optim.adam_offload + :members: + :undoc-members: + :show-inheritance: + +bmtrain.optim.optim\_manager module +----------------------------------- + +.. automodule:: bmtrain.optim.optim_manager + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.optim + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source-en/api/bmtrain.rst b/docs/source-en/api/bmtrain.rst index a1d742d6..8445e5f0 100644 --- a/docs/source-en/api/bmtrain.rst +++ b/docs/source-en/api/bmtrain.rst @@ -1,58 +1,140 @@ -==================== -bmtrain -==================== +bmtrain package +=============== +Subpackages +----------- -Initialization -========================================== +.. toctree:: + :maxdepth: 4 -.. autofunction:: bmtrain.init_distributed + bmtrain.benchmark + bmtrain.distributed + bmtrain.inspect + bmtrain.loss + bmtrain.lr_scheduler + bmtrain.nccl + bmtrain.nn + bmtrain.optim -Distributed Parameters and Modules -========================================== +Submodules +---------- -.. autoclass:: bmtrain.DistributedParameter - :members: - :show-inheritance: +bmtrain.block\_layer module +--------------------------- + +.. automodule:: bmtrain.block_layer + :members: + :undoc-members: + :show-inheritance: -.. autoclass:: bmtrain.ParameterInitializer - :members: - :show-inheritance: +.. bmtrain.debug module +.. -------------------- + +.. .. automodule:: bmtrain.debug +.. :members: +.. :undoc-members: +.. :show-inheritance: -.. autoclass:: bmtrain.DistributedModule - :members: - :show-inheritance: +bmtrain.global\_var module +-------------------------- + +.. automodule:: bmtrain.global_var + :members: + :undoc-members: + :show-inheritance: + +.. bmtrain.hook\_func module +.. ------------------------- + +.. .. automodule:: bmtrain.hook_func +.. :members: +.. :undoc-members: +.. :show-inheritance: + +bmtrain.init module +------------------- + +.. automodule:: bmtrain.init + :members: + :undoc-members: + :show-inheritance: + +bmtrain.layer module +-------------------- + +.. automodule:: bmtrain.layer + :members: + :undoc-members: + :show-inheritance: + +bmtrain.param\_init module +-------------------------- + +.. automodule:: bmtrain.param_init + :members: + :undoc-members: + :show-inheritance: + +bmtrain.parameter module +------------------------ + +.. automodule:: bmtrain.parameter + :members: DistributedParameter, ParameterInitializer + :undoc-members: + :show-inheritance: + +bmtrain.pipe\_layer module +-------------------------- + +.. automodule:: bmtrain.pipe_layer + :members: PipelineTransformerBlockList + :undoc-members: + :show-inheritance: + +bmtrain.store module +-------------------- + +.. automodule:: bmtrain.store + :members: save, load + :undoc-members: + :show-inheritance: + +bmtrain.synchronize module +-------------------------- + +.. automodule:: bmtrain.synchronize + :members: + :undoc-members: + :show-inheritance: + +bmtrain.utils module +-------------------- + +.. automodule:: bmtrain.utils + :members: + :undoc-members: + :show-inheritance: -.. autoclass:: bmtrain.CheckpointBlock - :members: - :show-inheritance: +bmtrain.wrapper module +---------------------- -.. autoclass:: bmtrain.TransformerBlockList - :members: - :show-inheritance: +.. automodule:: bmtrain.wrapper + :members: BMTrainModelWrapper + :undoc-members: + :show-inheritance: + +bmtrain.zero\_context module +---------------------------- -Methods for Parameters -========================================== +.. 
automodule:: bmtrain.zero_context + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- -.. autofunction:: bmtrain.init_parameters - -.. autofunction:: bmtrain.grouped_parameters - -.. autofunction:: bmtrain.save - -.. autofunction:: bmtrain.load - -Utilities -========================================== - -.. autofunction:: bmtrain.rank - -.. autofunction:: bmtrain.world_size - -.. autofunction:: bmtrain.print_rank - -.. autofunction:: bmtrain.synchronize - -.. autofunction:: bmtrain.sum_loss - -.. autofunction:: bmtrain.optim_step +.. automodule:: bmtrain + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source-en/api/inspect.rst b/docs/source-en/api/inspect.rst deleted file mode 100644 index 955fd347..00000000 --- a/docs/source-en/api/inspect.rst +++ /dev/null @@ -1,15 +0,0 @@ -==================== -bmtrain.inspect -==================== - -The `bmtrain.inspect` module is a module for debugging and analysis of distributed code. - -We recommend that you use the tools in this module to obtain the parameters and computing results in distributed models. - -.. autofunction:: bmtrain.inspect.inspect_model - -.. autofunction:: bmtrain.inspect.inspect_tensor - -.. autofunction:: bmtrain.inspect.record_tensor - -.. autofunction:: bmtrain.inspect.format_summary diff --git a/docs/source-en/api/lr_scheduler.rst b/docs/source-en/api/lr_scheduler.rst deleted file mode 100644 index e05f2651..00000000 --- a/docs/source-en/api/lr_scheduler.rst +++ /dev/null @@ -1,34 +0,0 @@ -====================== -bmtrain.lr_scheduler -====================== - -The `bmtrain.lr_scheduler` module provides the common learning rate schedulers for big model training. - - -.. autoclass:: bmtrain.lr_scheduler.WarmupLRScheduler - :members: - - -LR Schedulers -========================= - -.. autoclass:: bmtrain.lr_scheduler.NoDecay - :members: - :show-inheritance: - -.. autoclass:: bmtrain.lr_scheduler.Noam - :members: - :show-inheritance: - -.. autoclass:: bmtrain.lr_scheduler.Linear - :members: - :show-inheritance: - -.. autoclass:: bmtrain.lr_scheduler.Exponential - :members: - :show-inheritance: - -.. autoclass:: bmtrain.lr_scheduler.Cosine - :members: - :show-inheritance: - diff --git a/docs/source-en/api/modules.rst b/docs/source-en/api/modules.rst new file mode 100644 index 00000000..4350b5d7 --- /dev/null +++ b/docs/source-en/api/modules.rst @@ -0,0 +1,7 @@ +bmtrain +======= + +.. toctree:: + :maxdepth: 4 + + bmtrain diff --git a/docs/source-en/api/nccl.rst b/docs/source-en/api/nccl.rst deleted file mode 100644 index f32e7ba9..00000000 --- a/docs/source-en/api/nccl.rst +++ /dev/null @@ -1,27 +0,0 @@ -======================= -bmtrain.nccl -======================= - -.. autoclass:: bmtrain.nccl.NCCLCommunicator - :members: - :show-inheritance: - -.. autofunction:: bmtrain.nccl.getUniqueId - -.. autofunction:: bmtrain.nccl.commInitRank - -.. autofunction:: bmtrain.nccl.commDestroy - -.. autofunction:: bmtrain.nccl.allReduce - -.. autofunction:: bmtrain.nccl.broadcast - -.. autofunction:: bmtrain.nccl.reduce - -.. autofunction:: bmtrain.nccl.allGather - -.. autofunction:: bmtrain.nccl.reduceScatter - -.. autofunction:: bmtrain.nccl.groupStart - -.. 
autofunction:: bmtrain.nccl.groupEnd diff --git a/docs/source-en/conf.py b/docs/source-en/conf.py index c782f22c..6351767a 100644 --- a/docs/source-en/conf.py +++ b/docs/source-en/conf.py @@ -12,43 +12,50 @@ # import os import sys -sys.path.insert(0, os.path.abspath('../..')) + +sys.path.insert(0, os.path.abspath("../../..")) import recommonmark from recommonmark.transform import AutoStructify - # -- Project information ----------------------------------------------------- -project = 'BMTrain' -copyright = '2022, OpenBMB' -author = 'BMTrain Team' -autodoc_mock_imports = ["numpy", "tensorboard", "bmtrain.nccl._C", "bmtrain.optim._cpu", "bmtrain.optim._cuda", "bmtrain.loss._cuda"] +project = "BMTrain" +copyright = "2022, OpenBMB" +author = "BMTrain Team" +autodoc_mock_imports = [ + "numpy", + "tensorboard", + "bmtrain.nccl._C", + "bmtrain.optim._cpu", + "bmtrain.optim._cuda", + "bmtrain.loss._cuda", +] # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.napoleon', - 'sphinx.ext.mathjax', - 'recommonmark', - 'sphinx_markdown_tables', + "sphinx.ext.autodoc", + "sphinx.ext.napoleon", + "sphinx.ext.mathjax", + "recommonmark", + "sphinx_markdown_tables", ] -source_suffix = ['.rst', '.md'] +source_suffix = [".rst", ".md"] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. -language = 'en' +language = "en" # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. @@ -61,12 +68,12 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] -#html_stype="css/custom.css" -html_css_files=['css/custom.css' ] -html_js_files= ['js/custom.js' ] +html_static_path = ["_static"] +# html_stype="css/custom.css" +html_css_files = ["css/custom.css"] +html_js_files = ["js/custom.js"] diff --git a/docs/source-en/index.rst b/docs/source-en/index.rst index 8b2e7982..a2ec24f9 100644 --- a/docs/source-en/index.rst +++ b/docs/source-en/index.rst @@ -3,10 +3,10 @@ You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. -BMTrain's Documentation! -============================= +BMTrain Documentation +=============================== -**BMTrain** is an efficient large model training toolkit that can be used to train large models with tens of billions of parameters. It can train models in a distributed manner while keeping the code as simple as stand-alone training. +**BMTrain** is an efficient large model training toolkit that can be used to train large models with tens of billions of parameters. BMTrain keeps the code as simple as stand-alone training while training models in a distributed manner. ======================================= @@ -23,9 +23,13 @@ BMTrain's Documentation!
:caption: Package Reference api/bmtrain.rst - api/nccl.rst - api/inspect.rst - api/lr_scheduler.rst + api/bmtrain.benchmark.rst + api/bmtrain.distributed.rst + api/bmtrain.inspect.rst + api/bmtrain.loss.rst + api/bmtrain.lr_scheduler.rst + api/bmtrain.nccl.rst + api/bmtrain.optim.rst API ================== diff --git a/docs/source/api/bmtrain.benchmark.rst_bk b/docs/source/api/bmtrain.benchmark.rst_bk new file mode 100644 index 00000000..f8b2902d --- /dev/null +++ b/docs/source/api/bmtrain.benchmark.rst_bk @@ -0,0 +1,53 @@ +bmtrain.benchmark package +========================= + +Submodules +---------- + +bmtrain.benchmark.all\_gather module +------------------------------------ + +.. automodule:: bmtrain.benchmark.all_gather + :members: + :undoc-members: + :show-inheritance: + +bmtrain.benchmark.reduce\_scatter module +---------------------------------------- + +.. automodule:: bmtrain.benchmark.reduce_scatter + :members: + :undoc-members: + :show-inheritance: + +bmtrain.benchmark.send\_recv module +----------------------------------- + +.. automodule:: bmtrain.benchmark.send_recv + :members: + :undoc-members: + :show-inheritance: + +bmtrain.benchmark.shape module +------------------------------ + +.. automodule:: bmtrain.benchmark.shape + :members: + :undoc-members: + :show-inheritance: + +bmtrain.benchmark.utils module +------------------------------ + +.. automodule:: bmtrain.benchmark.utils + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.benchmark + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/bmtrain.distributed.rst_bk b/docs/source/api/bmtrain.distributed.rst_bk new file mode 100644 index 00000000..ef41db07 --- /dev/null +++ b/docs/source/api/bmtrain.distributed.rst_bk @@ -0,0 +1,21 @@ +bmtrain.distributed package +=========================== + +Submodules +---------- + +bmtrain.distributed.ops module +------------------------------ + +.. automodule:: bmtrain.distributed.ops + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.distributed + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/bmtrain.inspect.rst b/docs/source/api/bmtrain.inspect.rst new file mode 100644 index 00000000..c57195ad --- /dev/null +++ b/docs/source/api/bmtrain.inspect.rst @@ -0,0 +1,37 @@ +bmtrain.inspect package +======================= + +Submodules +---------- + +bmtrain.inspect.format module +----------------------------- + +.. automodule:: bmtrain.inspect.format + :members: format_summary + :undoc-members: + :show-inheritance: + +bmtrain.inspect.model module +---------------------------- + +.. automodule:: bmtrain.inspect.model + :members: inspect_model + :undoc-members: + :show-inheritance: + +bmtrain.inspect.tensor module +----------------------------- + +.. automodule:: bmtrain.inspect.tensor + :members: inspect_tensor, InspectTensor + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.inspect + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/bmtrain.loss.rst b/docs/source/api/bmtrain.loss.rst new file mode 100644 index 00000000..03b65646 --- /dev/null +++ b/docs/source/api/bmtrain.loss.rst @@ -0,0 +1,21 @@ +bmtrain.loss package +==================== + +Submodules +---------- + +bmtrain.loss.cross\_entropy module +---------------------------------- + +.. 
automodule:: bmtrain.loss.cross_entropy + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.loss + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/bmtrain.lr_scheduler.rst b/docs/source/api/bmtrain.lr_scheduler.rst new file mode 100644 index 00000000..0ba033af --- /dev/null +++ b/docs/source/api/bmtrain.lr_scheduler.rst @@ -0,0 +1,61 @@ +bmtrain.lr\_scheduler package +============================= + +Submodules +---------- + +bmtrain.lr\_scheduler.cosine module +----------------------------------- + +.. automodule:: bmtrain.lr_scheduler.cosine + :members: + :undoc-members: + :show-inheritance: + +bmtrain.lr\_scheduler.exponential module +---------------------------------------- + +.. automodule:: bmtrain.lr_scheduler.exponential + :members: + :undoc-members: + :show-inheritance: + +bmtrain.lr\_scheduler.linear module +----------------------------------- + +.. automodule:: bmtrain.lr_scheduler.linear + :members: + :undoc-members: + :show-inheritance: + +bmtrain.lr\_scheduler.no\_decay module +-------------------------------------- + +.. automodule:: bmtrain.lr_scheduler.no_decay + :members: + :undoc-members: + :show-inheritance: + +bmtrain.lr\_scheduler.noam module +--------------------------------- + +.. automodule:: bmtrain.lr_scheduler.noam + :members: + :undoc-members: + :show-inheritance: + +bmtrain.lr\_scheduler.warmup module +----------------------------------- + +.. automodule:: bmtrain.lr_scheduler.warmup + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.lr_scheduler + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/bmtrain.nccl.rst_bk b/docs/source/api/bmtrain.nccl.rst_bk new file mode 100644 index 00000000..3755d9ef --- /dev/null +++ b/docs/source/api/bmtrain.nccl.rst_bk @@ -0,0 +1,21 @@ +bmtrain.nccl package +==================== + +Submodules +---------- + +bmtrain.nccl.enums module +------------------------- + +.. automodule:: bmtrain.nccl.enums + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.nccl + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/bmtrain.nn.rst b/docs/source/api/bmtrain.nn.rst new file mode 100644 index 00000000..8e2a531f --- /dev/null +++ b/docs/source/api/bmtrain.nn.rst @@ -0,0 +1,53 @@ +bmtrain.nn package +================== + +Submodules +---------- + +bmtrain.nn.column\_parallel\_linear module +------------------------------------------ + +.. automodule:: bmtrain.nn.column_parallel_linear + :members: + :undoc-members: + :show-inheritance: + +bmtrain.nn.linear module +------------------------ + +.. automodule:: bmtrain.nn.linear + :members: + :undoc-members: + :show-inheritance: + +bmtrain.nn.parallel\_embedding module +------------------------------------- + +.. automodule:: bmtrain.nn.parallel_embedding + :members: + :undoc-members: + :show-inheritance: + +bmtrain.nn.parallel\_linear\_func module +---------------------------------------- + +.. automodule:: bmtrain.nn.parallel_linear_func + :members: + :undoc-members: + :show-inheritance: + +bmtrain.nn.row\_parallel\_linear module +--------------------------------------- + +.. automodule:: bmtrain.nn.row_parallel_linear + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. 
automodule:: bmtrain.nn + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/bmtrain.optim.rst b/docs/source/api/bmtrain.optim.rst new file mode 100644 index 00000000..2d47a3dd --- /dev/null +++ b/docs/source/api/bmtrain.optim.rst @@ -0,0 +1,37 @@ +bmtrain.optim package +===================== + +Submodules +---------- + +bmtrain.optim.adam module +------------------------- + +.. automodule:: bmtrain.optim.adam + :members: + :undoc-members: + :show-inheritance: + +bmtrain.optim.adam\_offload module +---------------------------------- + +.. automodule:: bmtrain.optim.adam_offload + :members: + :undoc-members: + :show-inheritance: + +bmtrain.optim.optim\_manager module +----------------------------------- + +.. automodule:: bmtrain.optim.optim_manager + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.optim + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/bmtrain.rst b/docs/source/api/bmtrain.rst index a1d742d6..8445e5f0 100644 --- a/docs/source/api/bmtrain.rst +++ b/docs/source/api/bmtrain.rst @@ -1,58 +1,140 @@ -==================== -bmtrain -==================== +bmtrain package +=============== +Subpackages +----------- -Initialization -========================================== +.. toctree:: + :maxdepth: 4 -.. autofunction:: bmtrain.init_distributed + bmtrain.benchmark + bmtrain.distributed + bmtrain.inspect + bmtrain.loss + bmtrain.lr_scheduler + bmtrain.nccl + bmtrain.nn + bmtrain.optim -Distributed Parameters and Modules -========================================== +Submodules +---------- -.. autoclass:: bmtrain.DistributedParameter - :members: - :show-inheritance: +bmtrain.block\_layer module +--------------------------- + +.. automodule:: bmtrain.block_layer + :members: + :undoc-members: + :show-inheritance: -.. autoclass:: bmtrain.ParameterInitializer - :members: - :show-inheritance: +.. bmtrain.debug module +.. -------------------- + +.. .. automodule:: bmtrain.debug +.. :members: +.. :undoc-members: +.. :show-inheritance: -.. autoclass:: bmtrain.DistributedModule - :members: - :show-inheritance: +bmtrain.global\_var module +-------------------------- + +.. automodule:: bmtrain.global_var + :members: + :undoc-members: + :show-inheritance: + +.. bmtrain.hook\_func module +.. ------------------------- + +.. .. automodule:: bmtrain.hook_func +.. :members: +.. :undoc-members: +.. :show-inheritance: + +bmtrain.init module +------------------- + +.. automodule:: bmtrain.init + :members: + :undoc-members: + :show-inheritance: + +bmtrain.layer module +-------------------- + +.. automodule:: bmtrain.layer + :members: + :undoc-members: + :show-inheritance: + +bmtrain.param\_init module +-------------------------- + +.. automodule:: bmtrain.param_init + :members: + :undoc-members: + :show-inheritance: + +bmtrain.parameter module +------------------------ + +.. automodule:: bmtrain.parameter + :members: DistributedParameter, ParameterInitializer + :undoc-members: + :show-inheritance: + +bmtrain.pipe\_layer module +-------------------------- + +.. automodule:: bmtrain.pipe_layer + :members: PipelineTransformerBlockList + :undoc-members: + :show-inheritance: + +bmtrain.store module +-------------------- + +.. automodule:: bmtrain.store + :members: save, load + :undoc-members: + :show-inheritance: + +bmtrain.synchronize module +-------------------------- + +.. 
automodule:: bmtrain.synchronize + :members: + :undoc-members: + :show-inheritance: + +bmtrain.utils module +-------------------- + +.. automodule:: bmtrain.utils + :members: + :undoc-members: + :show-inheritance: -.. autoclass:: bmtrain.CheckpointBlock - :members: - :show-inheritance: +bmtrain.wrapper module +---------------------- -.. autoclass:: bmtrain.TransformerBlockList - :members: - :show-inheritance: +.. automodule:: bmtrain.wrapper + :members: BMTrainModelWrapper + :undoc-members: + :show-inheritance: + +bmtrain.zero\_context module +---------------------------- -Methods for Parameters -========================================== +.. automodule:: bmtrain.zero_context + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- -.. autofunction:: bmtrain.init_parameters - -.. autofunction:: bmtrain.grouped_parameters - -.. autofunction:: bmtrain.save - -.. autofunction:: bmtrain.load - -Utilities -========================================== - -.. autofunction:: bmtrain.rank - -.. autofunction:: bmtrain.world_size - -.. autofunction:: bmtrain.print_rank - -.. autofunction:: bmtrain.synchronize - -.. autofunction:: bmtrain.sum_loss - -.. autofunction:: bmtrain.optim_step +.. automodule:: bmtrain + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/inspect.rst b/docs/source/api/inspect.rst deleted file mode 100644 index 955fd347..00000000 --- a/docs/source/api/inspect.rst +++ /dev/null @@ -1,15 +0,0 @@ -==================== -bmtrain.inspect -==================== - -The `bmtrain.inspect` module is a module for debugging and analysis of distributed code. - -We recommend that you use the tools in this module to obtain the parameters and computing results in distributed models. - -.. autofunction:: bmtrain.inspect.inspect_model - -.. autofunction:: bmtrain.inspect.inspect_tensor - -.. autofunction:: bmtrain.inspect.record_tensor - -.. autofunction:: bmtrain.inspect.format_summary diff --git a/docs/source/api/lr_scheduler.rst b/docs/source/api/lr_scheduler.rst deleted file mode 100644 index e05f2651..00000000 --- a/docs/source/api/lr_scheduler.rst +++ /dev/null @@ -1,34 +0,0 @@ -====================== -bmtrain.lr_scheduler -====================== - -The `bmtrain.lr_scheduler` module provides the common learning rate schedulers for big model training. - - -.. autoclass:: bmtrain.lr_scheduler.WarmupLRScheduler - :members: - - -LR Schedulers -========================= - -.. autoclass:: bmtrain.lr_scheduler.NoDecay - :members: - :show-inheritance: - -.. autoclass:: bmtrain.lr_scheduler.Noam - :members: - :show-inheritance: - -.. autoclass:: bmtrain.lr_scheduler.Linear - :members: - :show-inheritance: - -.. autoclass:: bmtrain.lr_scheduler.Exponential - :members: - :show-inheritance: - -.. autoclass:: bmtrain.lr_scheduler.Cosine - :members: - :show-inheritance: - diff --git a/docs/source/api/modules.rst b/docs/source/api/modules.rst new file mode 100644 index 00000000..4350b5d7 --- /dev/null +++ b/docs/source/api/modules.rst @@ -0,0 +1,7 @@ +bmtrain +======= + +.. toctree:: + :maxdepth: 4 + + bmtrain diff --git a/docs/source/api/nccl.rst b/docs/source/api/nccl.rst deleted file mode 100644 index f32e7ba9..00000000 --- a/docs/source/api/nccl.rst +++ /dev/null @@ -1,27 +0,0 @@ -======================= -bmtrain.nccl -======================= - -.. autoclass:: bmtrain.nccl.NCCLCommunicator - :members: - :show-inheritance: - -.. autofunction:: bmtrain.nccl.getUniqueId - -.. autofunction:: bmtrain.nccl.commInitRank - -.. 
autofunction:: bmtrain.nccl.commDestroy - -.. autofunction:: bmtrain.nccl.allReduce - -.. autofunction:: bmtrain.nccl.broadcast - -.. autofunction:: bmtrain.nccl.reduce - -.. autofunction:: bmtrain.nccl.allGather - -.. autofunction:: bmtrain.nccl.reduceScatter - -.. autofunction:: bmtrain.nccl.groupStart - -.. autofunction:: bmtrain.nccl.groupEnd diff --git a/docs/source/index.rst b/docs/source/index.rst index 5dfeec84..a2ec24f9 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -23,9 +23,13 @@ BMTrain 文档 :caption: Package Reference api/bmtrain.rst - api/nccl.rst - api/inspect.rst - api/lr_scheduler.rst + api/bmtrain.benchmark.rst + api/bmtrain.distributed.rst + api/bmtrain.inspect.rst + api/bmtrain.loss.rst + api/bmtrain.lr_scheduler.rst + api/bmtrain.nccl.rst + api/bmtrain.optim.rst API ==================