pre-commit fixes
dblalock committed Aug 14, 2023
1 parent 225ceac commit d53f0e5
Showing 3 changed files with 20 additions and 13 deletions.
25 changes: 15 additions & 10 deletions llmfoundry/optim/lion8b.py
@@ -89,11 +89,13 @@ def __init__(self,
self._error_correction = error_correction
self._compress_state_dict = compress_state_dict

-defaults = {'lr': lr,
-            'initial_lr': lr,
-            'betas': betas,
-            'weight_decay': weight_decay,
-            'fused': _fused}
+defaults = {
+    'lr': lr,
+    'initial_lr': lr,
+    'betas': betas,
+    'weight_decay': weight_decay,
+    'fused': _fused
+}
super().__init__(params, defaults)

@torch.no_grad()
@@ -182,9 +184,10 @@ def state_dict(self):
if _KEY_MOMENTUM in param_state: # true if we've taken any steps
qtensor = param_state.pop(_KEY_MOMENTUM)
assert isinstance(qtensor, _MaybeQuantizedTensor) # pyright
-param_state.update(qtensor.state_dict(
-    name=_KEY_MOMENTUM,
-    allow_quantized=self._compress_state_dict))
+param_state.update(
+    qtensor.state_dict(
+        name=_KEY_MOMENTUM,
+        allow_quantized=self._compress_state_dict))
opt_state[param_id] = param_state
return d

@@ -224,8 +227,10 @@ def state_dict(self,
if self.is_quantized() and allow_quantized:
assert self.quantized is not None # pyright
assert self.scales is not None # pyright
-return {f'{name}::quantized': self.quantized,
-        f'{name}::scales': self.scales}
+return {
+    f'{name}::quantized': self.quantized,
+    f'{name}::scales': self.scales
+}
return {name: self.materialize().to(dtype=torch.bfloat16)}

def load_state_dict(self, d: Dict[str, torch.Tensor], name: str) -> None:
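For context on the state_dict hunks above: the momentum state is held in a _MaybeQuantizedTensor, whose state_dict(name, allow_quantized) either emits the compressed payload under '<name>::quantized' / '<name>::scales' keys or falls back to a single bf16 tensor stored under name. Below is a minimal standalone sketch of that key scheme only; it is not the library's actual class, and the quantization logic itself is elided.

# Illustrative sketch of the serialization branch touched above; not the real
# _MaybeQuantizedTensor. Quantization/dequantization is omitted.
from typing import Dict, Optional

import torch


class MaybeQuantizedTensorSketch:

    def __init__(self,
                 exact: torch.Tensor,
                 quantized: Optional[torch.Tensor] = None,
                 scales: Optional[torch.Tensor] = None):
        self._exact = exact          # full-precision fallback tensor
        self.quantized = quantized   # int8 payload, present when quantized
        self.scales = scales         # per-block scales for dequantization

    def is_quantized(self) -> bool:
        return self.quantized is not None

    def materialize(self) -> torch.Tensor:
        return self._exact  # the real class would dequantize here

    def state_dict(self, name: str,
                   allow_quantized: bool = False) -> Dict[str, torch.Tensor]:
        # Same branch as in the hunk: two namespaced entries when saving the
        # compressed form is allowed, otherwise one bf16 tensor under `name`.
        if self.is_quantized() and allow_quantized:
            return {
                f'{name}::quantized': self.quantized,
                f'{name}::scales': self.scales
            }
        return {name: self.materialize().to(dtype=torch.bfloat16)}


# Example: with compression allowed, the saved state ends up holding
# '<momentum key>::quantized' and '<momentum key>::scales' entries
# ('momentum' here is just an example key name).
m = MaybeQuantizedTensorSketch(exact=torch.randn(4),
                               quantized=torch.zeros(4, dtype=torch.int8),
                               scales=torch.ones(1))
print(sorted(m.state_dict('momentum', allow_quantized=True).keys()))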
5 changes: 3 additions & 2 deletions llmfoundry/utils/config_utils.py
@@ -4,7 +4,7 @@
import contextlib
import math
import warnings
-from typing import Dict, Optional, Union, Mapping
+from typing import Dict, Mapping, Optional, Union

from composer.utils import dist
from omegaconf import DictConfig
@@ -89,7 +89,8 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]):

# no mixed precision needed for weights when they're already 16 bits
master_dtype = model_cfg.get('master_weights_dtype')
-small_dtypes = ('bf16', 'f16', 'float16', 'bfloat16', 'amp_fp16', 'amp_bf16')
+small_dtypes = ('bf16', 'f16', 'float16', 'bfloat16', 'amp_fp16',
+                'amp_bf16')
if fsdp_config and master_dtype in small_dtypes:
reduce_dtype = None
buffer_dtype = None
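The config_utils.py hunk only rewraps the small_dtypes tuple; the underlying check, visible in the surrounding context, is that when the master weights dtype is already a 16-bit (or AMP 16-bit) variant and FSDP is configured, no extra mixed-precision casting of the weights is needed. A small sketch of that condition follows; skip_mixed_precision is a hypothetical stand-in for logic that lives inline in process_init_device.

# Hypothetical helper isolating the condition shown in the hunk above; in the
# repository this check runs inline inside process_init_device.
from typing import Dict, Optional


def skip_mixed_precision(master_dtype: Optional[str],
                         fsdp_config: Optional[Dict]) -> bool:
    # No mixed precision needed for weights when they're already 16 bits.
    small_dtypes = ('bf16', 'f16', 'float16', 'bfloat16', 'amp_fp16',
                    'amp_bf16')
    return bool(fsdp_config) and master_dtype in small_dtypes


print(skip_mixed_precision('amp_bf16', {'sharding_strategy': 'FULL_SHARD'}))  # True
print(skip_mixed_precision('fp32', {'sharding_strategy': 'FULL_SHARD'}))      # False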
3 changes: 2 additions & 1 deletion scripts/train/train.py
@@ -200,7 +200,8 @@ def main(cfg: DictConfig):
assert isinstance(fsdp_config, Dict)
if dist.get_world_size() == 1:
warnings.warn(
-    'FSDP is not applicable for single-GPU training. Reverting to DDP.')
+    'FSDP is not applicable for single-GPU training. Reverting to DDP.'
+)
cfg.pop('fsdp_config')
fsdp_config = None

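The train.py change only rewraps the warning call; per the hunk, the surrounding behavior is to warn and drop fsdp_config when the world size is 1. A rough sketch follows, under the assumption that the logic is pulled out into a hypothetical helper (in the script it runs inline inside main(cfg)).

# Hypothetical extraction of the single-GPU FSDP fallback shown above; in
# scripts/train/train.py this runs inline inside main(cfg).
import warnings

from composer.utils import dist
from omegaconf import DictConfig


def maybe_drop_fsdp_config(cfg: DictConfig):
    fsdp_config = cfg.get('fsdp_config', None)
    if fsdp_config is not None and dist.get_world_size() == 1:
        warnings.warn(
            'FSDP is not applicable for single-GPU training. Reverting to DDP.'
        )
        cfg.pop('fsdp_config')
        fsdp_config = None
    return fsdp_config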
