Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

8-bit LION, take 2 #514

Merged
merged 35 commits into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
c7c5ae6
add decoupled lion8b optimizer + tests + builder option + deps
dblalock Aug 10, 2023
cace0b6
Merge branch 'main' into davis/lion8b-v2
dblalock Aug 10, 2023
1a14857
pre-commit fixes
dblalock Aug 11, 2023
e32f1bf
merge upstream
dblalock Aug 11, 2023
16ca215
move lion8b kernels dep to "gpu" extra_deps
dblalock Aug 11, 2023
014ba69
Merge branch 'main' into davis/lion8b-v2
dblalock Aug 11, 2023
f391b7f
move fused error checks to llmfoundry
dblalock Aug 11, 2023
bcf55bf
make precommit + CodeQL happy?
dblalock Aug 11, 2023
6fc1782
disable fsdp param_dtype for low-bit master weights
dblalock Aug 12, 2023
ba0e317
add low-precision master weights option + rm needles .get(..., None)
dblalock Aug 12, 2023
7a55e07
fix missing import in config_utils
dblalock Aug 14, 2023
225ceac
hopefully fix lion8b fsdp checkpointing
dblalock Aug 14, 2023
d53f0e5
pre-commit fixes
dblalock Aug 14, 2023
c9217b6
Merge branch 'main' into davis/lion8b-v2
dblalock Aug 14, 2023
78fbfa9
Merge branch 'main' into davis/lion8b-v2
dblalock Aug 14, 2023
b1125aa
address pr comments
dblalock Aug 14, 2023
71e3f9c
merge upstream
dblalock Aug 14, 2023
476a9ec
fix descent + zero grad tests not being as stringent as intended
dblalock Aug 15, 2023
b87ca31
tiny style change
dblalock Aug 15, 2023
f90a71c
address more pr comments + WIP draft of FSDP checkpointing test
dblalock Aug 15, 2023
0eb3420
partial fix of fsdp state dict test
dblalock Aug 17, 2023
a2a0104
fsdp state dict test passing
dblalock Aug 18, 2023
afd9699
get fsdp state dict test passing with different sharding strategies
dblalock Aug 18, 2023
dd4ccb3
remove state key name indirection as per pr comments
dblalock Aug 18, 2023
abdc6a6
make precommit + pyright happy
dblalock Aug 18, 2023
061e74d
merge main
dblalock Aug 18, 2023
5082966
fix broken merge
dblalock Aug 18, 2023
fbde16b
skip fsdp checkpoint test for torch 1.13.1 since...config classes mis…
dblalock Aug 19, 2023
7adfd57
fix wrong var for model config (manual merge fail)
dblalock Aug 19, 2023
8f84c41
Merge branch 'main' into davis/lion8b-v2
dblalock Aug 21, 2023
8dabf20
Merge branch 'main' into davis/lion8b-v2
dblalock Aug 24, 2023
284a855
print thruputs in thruput test as per pr comments
dblalock Aug 24, 2023
7e4c11f
merge upstream
dblalock Aug 24, 2023
b81f7bb
Merge branch 'main' into davis/lion8b-v2
dblalock Aug 24, 2023
eae355f
merge upstream
dblalock Aug 24, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion llmfoundry/optim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,9 @@

from llmfoundry.optim.adaptive_lion import DecoupledAdaLRLion, DecoupledClipLion
from llmfoundry.optim.lion import DecoupledLionW
from llmfoundry.optim.lion8b import DecoupledLionW_8bit

__all__ = ['DecoupledLionW', 'DecoupledClipLion', 'DecoupledAdaLRLion']
__all__ = [
'DecoupledLionW', 'DecoupledLionW_8bit', 'DecoupledClipLion',
'DecoupledAdaLRLion'
]
417 changes: 417 additions & 0 deletions llmfoundry/optim/lion8b.py
vchiley marked this conversation as resolved.
Show resolved Hide resolved

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion llmfoundry/utils/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
LayerFreezing, MonolithicCheckpointSaver,
ScheduledGarbageCollector)
from llmfoundry.optim import (DecoupledAdaLRLion, DecoupledClipLion,
DecoupledLionW)
DecoupledLionW, DecoupledLionW_8bit)


def build_callback(name: str, kwargs: Dict[str, Any]):
Expand Down Expand Up @@ -115,6 +115,10 @@ def build_optimizer(cfg: DictConfig, model: torch.nn.Module):
timeout=cfg.timeout,
lr_penalty=cfg.lr_penalty,
min_scale=cfg.min_scale)
elif cfg.name.lower() == 'decoupled_lionw_8b':
# str() cast is just for pyright
kwargs = {str(k): v for k, v in cfg.items() if k != 'name'}
vchiley marked this conversation as resolved.
Show resolved Hide resolved
return DecoupledLionW_8bit(model.parameters(), **kwargs)
else:
raise ValueError(f'Not sure how to build optimizer: {cfg.name}')

Expand Down
21 changes: 20 additions & 1 deletion llmfoundry/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import contextlib
import math
import warnings
from typing import Dict, Optional, Union
from typing import Dict, Mapping, Optional, Union

from composer.utils import dist
from omegaconf import DictConfig
Expand Down Expand Up @@ -86,6 +86,25 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]):
# Set defaults for mixed initialization
fsdp_config.setdefault('use_orig_params', False)
fsdp_config.setdefault('load_monolith_rank0_only', True)

# no mixed precision needed for weights when they're already 16 bits
master_dtype = model_cfg.get('master_weights_dtype')
small_dtypes = ('bf16', 'f16', 'float16', 'bfloat16', 'amp_fp16',
dblalock marked this conversation as resolved.
Show resolved Hide resolved
'amp_bf16')
if fsdp_config and master_dtype in small_dtypes:
reduce_dtype = None
buffer_dtype = None
mixed_precision = fsdp_config.get('mixed_precision')
if isinstance(mixed_precision, Mapping):
dblalock marked this conversation as resolved.
Show resolved Hide resolved
reduce_dtype = mixed_precision.get('reduce_dtype')
buffer_dtype = mixed_precision.get('buffer_dtype')
fsdp_config['mixed_precision'] = {
'param_dtype': None,
'reduce_dtype': reduce_dtype,
'buffer_dtype': buffer_dtype,
'keep_low_precision_grads': True,
}

return init_context


Expand Down
27 changes: 16 additions & 11 deletions scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,15 +194,16 @@ def main(cfg: DictConfig):
cfg = update_batch_size_info(cfg)

# Read FSDP Config as a dict
fsdp_config = cfg.get('fsdp_config', None)
fsdp_config = om.to_container(fsdp_config,
resolve=True) if fsdp_config else None
assert isinstance(fsdp_config, Dict) or fsdp_config is None
if dist.get_world_size() == 1 and fsdp_config is not None:
warnings.warn(
'FSDP is not applicable for single-GPU training. Reverting to DDP.')
cfg.pop('fsdp_config')
fsdp_config = None
fsdp_config = cfg.get('fsdp_config')
if fsdp_config is not None:
fsdp_config = om.to_container(fsdp_config, resolve=True)
assert isinstance(fsdp_config, Dict)
if dist.get_world_size() == 1:
warnings.warn(
'FSDP is not applicable for single-GPU training. Reverting to DDP.'
)
cfg.pop('fsdp_config')
fsdp_config = None

init_context = process_init_device(cfg.model, fsdp_config)

Expand All @@ -212,13 +213,16 @@ def main(cfg: DictConfig):
# Build Model
print('Initializing model...')
with init_context:
if cfg.get('lora',
None) is not None: # frozen model + trainable lora modules
if cfg.get('lora') is not None: # frozen model + trainable lora modules
model: ComposerHFCausalLM = build_composer_peft_model(
cfg.model, cfg.lora, tokenizer)
print_trainable_parameters(model) # should not be 100%
else: # standard model
model = build_composer_model(cfg.model, tokenizer)
if cfg.model.get('master_weights_dtype') in ('bf16', 'bfloat16'):
model = model.to(dtype=torch.bfloat16)
elif cfg.model.get('master_weights_dtype') in ('f16', 'float16'):
model = model.to(dtype=torch.float16)
cfg.n_params = sum(p.numel() for p in model.parameters())
print(f'{cfg.n_params=:.2e}')

Expand Down Expand Up @@ -342,5 +346,6 @@ def main(cfg: DictConfig):
yaml_cfg = om.load(f)
cli_cfg = om.from_cli(args_list)
cfg = om.merge(yaml_cfg, cli_cfg)
om.resolve(cfg)
assert isinstance(cfg, DictConfig)
main(cfg)
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@

extra_deps['gpu'] = [
'flash-attn==v1.0.3.post0',
'mosaicml-turbo>=0.0.2,<0.1',
# PyPI does not support direct dependencies, so we remove this line before uploading from PyPI
'xentropy-cuda-lib@git+https://github.com/HazyResearch/[email protected]#subdirectory=csrc/xentropy',
]
Expand Down
Loading
Loading