From d53f0e58c2b69aeead7314962bd7e61a80d09aed Mon Sep 17 00:00:00 2001
From: root
Date: Mon, 14 Aug 2023 05:49:10 +0000
Subject: [PATCH] pre-commit fixes

---
 llmfoundry/optim/lion8b.py       | 25 +++++++++++++++----------
 llmfoundry/utils/config_utils.py |  5 +++--
 scripts/train/train.py           |  3 ++-
 3 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/llmfoundry/optim/lion8b.py b/llmfoundry/optim/lion8b.py
index 16cebf77b3..c22689a212 100644
--- a/llmfoundry/optim/lion8b.py
+++ b/llmfoundry/optim/lion8b.py
@@ -89,11 +89,13 @@ def __init__(self,
         self._error_correction = error_correction
         self._compress_state_dict = compress_state_dict
 
-        defaults = {'lr': lr,
-                    'initial_lr': lr,
-                    'betas': betas,
-                    'weight_decay': weight_decay,
-                    'fused': _fused}
+        defaults = {
+            'lr': lr,
+            'initial_lr': lr,
+            'betas': betas,
+            'weight_decay': weight_decay,
+            'fused': _fused
+        }
         super().__init__(params, defaults)
 
     @torch.no_grad()
@@ -182,9 +184,10 @@ def state_dict(self):
             if _KEY_MOMENTUM in param_state:  # true if we've taken any steps
                 qtensor = param_state.pop(_KEY_MOMENTUM)
                 assert isinstance(qtensor, _MaybeQuantizedTensor)  # pyright
-                param_state.update(qtensor.state_dict(
-                    name=_KEY_MOMENTUM,
-                    allow_quantized=self._compress_state_dict))
+                param_state.update(
+                    qtensor.state_dict(
+                        name=_KEY_MOMENTUM,
+                        allow_quantized=self._compress_state_dict))
             opt_state[param_id] = param_state
         return d
 
@@ -224,8 +227,10 @@ def state_dict(self,
         if self.is_quantized() and allow_quantized:
             assert self.quantized is not None  # pyright
             assert self.scales is not None  # pyright
-            return {f'{name}::quantized': self.quantized,
-                    f'{name}::scales': self.scales}
+            return {
+                f'{name}::quantized': self.quantized,
+                f'{name}::scales': self.scales
+            }
         return {name: self.materialize().to(dtype=torch.bfloat16)}
 
     def load_state_dict(self, d: Dict[str, torch.Tensor], name: str) -> None:
diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py
index e8d130230c..79c9fe8011 100644
--- a/llmfoundry/utils/config_utils.py
+++ b/llmfoundry/utils/config_utils.py
@@ -4,7 +4,7 @@
 import contextlib
 import math
 import warnings
-from typing import Dict, Optional, Union, Mapping
+from typing import Dict, Mapping, Optional, Union
 
 from composer.utils import dist
 from omegaconf import DictConfig
@@ -89,7 +89,8 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]):
 
     # no mixed precision needed for weights when they're already 16 bits
     master_dtype = model_cfg.get('master_weights_dtype')
-    small_dtypes = ('bf16', 'f16', 'float16', 'bfloat16', 'amp_fp16', 'amp_bf16')
+    small_dtypes = ('bf16', 'f16', 'float16', 'bfloat16', 'amp_fp16',
+                    'amp_bf16')
     if fsdp_config and master_dtype in small_dtypes:
         reduce_dtype = None
         buffer_dtype = None
diff --git a/scripts/train/train.py b/scripts/train/train.py
index 611eabb155..2602bf8b43 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -200,7 +200,8 @@ def main(cfg: DictConfig):
         assert isinstance(fsdp_config, Dict)
     if dist.get_world_size() == 1:
         warnings.warn(
-            'FSDP is not applicable for single-GPU training. Reverting to DDP.')
+            'FSDP is not applicable for single-GPU training. Reverting to DDP.'
+        )
         cfg.pop('fsdp_config')
         fsdp_config = None