
8-bit LION, take 2 #514

Merged: 35 commits into main from davis/lion8b-v2 on Aug 24, 2023

Commits (35):
c7c5ae6  add decoupled lion8b optimizer + tests + builder option + deps (dblalock, Aug 10, 2023)
cace0b6  Merge branch 'main' into davis/lion8b-v2 (dblalock, Aug 10, 2023)
1a14857  pre-commit fixes (dblalock, Aug 11, 2023)
e32f1bf  merge upstream (dblalock, Aug 11, 2023)
16ca215  move lion8b kernels dep to "gpu" extra_deps (dblalock, Aug 11, 2023)
014ba69  Merge branch 'main' into davis/lion8b-v2 (dblalock, Aug 11, 2023)
f391b7f  move fused error checks to llmfoundry (dblalock, Aug 11, 2023)
bcf55bf  make precommit + CodeQL happy? (dblalock, Aug 11, 2023)
6fc1782  disable fsdp param_dtype for low-bit master weights (dblalock, Aug 12, 2023)
ba0e317  add low-precision master weights option + rm needles .get(..., None) (dblalock, Aug 12, 2023)
7a55e07  fix missing import in config_utils (dblalock, Aug 14, 2023)
225ceac  hopefully fix lion8b fsdp checkpointing (dblalock, Aug 14, 2023)
d53f0e5  pre-commit fixes (dblalock, Aug 14, 2023)
c9217b6  Merge branch 'main' into davis/lion8b-v2 (dblalock, Aug 14, 2023)
78fbfa9  Merge branch 'main' into davis/lion8b-v2 (dblalock, Aug 14, 2023)
b1125aa  address pr comments (dblalock, Aug 14, 2023)
71e3f9c  merge upstream (dblalock, Aug 14, 2023)
476a9ec  fix descent + zero grad tests not being as stringent as intended (dblalock, Aug 15, 2023)
b87ca31  tiny style change (dblalock, Aug 15, 2023)
f90a71c  address more pr comments + WIP draft of FSDP checkpointing test (dblalock, Aug 15, 2023)
0eb3420  partial fix of fsdp state dict test (dblalock, Aug 17, 2023)
a2a0104  fsdp state dict test passing (dblalock, Aug 18, 2023)
afd9699  get fsdp state dict test passing with different sharding strategies (dblalock, Aug 18, 2023)
dd4ccb3  remove state key name indirection as per pr comments (dblalock, Aug 18, 2023)
abdc6a6  make precommit + pyright happy (dblalock, Aug 18, 2023)
061e74d  merge main (dblalock, Aug 18, 2023)
5082966  fix broken merge (dblalock, Aug 18, 2023)
fbde16b  skip fsdp checkpoint test for torch 1.13.1 since...config classes mis… (dblalock, Aug 19, 2023)
7adfd57  fix wrong var for model config (manual merge fail) (dblalock, Aug 19, 2023)
8f84c41  Merge branch 'main' into davis/lion8b-v2 (dblalock, Aug 21, 2023)
8dabf20  Merge branch 'main' into davis/lion8b-v2 (dblalock, Aug 24, 2023)
284a855  print thruputs in thruput test as per pr comments (dblalock, Aug 24, 2023)
7e4c11f  merge upstream (dblalock, Aug 24, 2023)
b81f7bb  Merge branch 'main' into davis/lion8b-v2 (dblalock, Aug 24, 2023)
eae355f  merge upstream (dblalock, Aug 24, 2023)
6 changes: 5 additions & 1 deletion llmfoundry/optim/__init__.py
@@ -3,5 +3,9 @@

 from llmfoundry.optim.adaptive_lion import DecoupledAdaLRLion, DecoupledClipLion
 from llmfoundry.optim.lion import DecoupledLionW
+from llmfoundry.optim.lion8b import DecoupledLionW_8bit

-__all__ = ['DecoupledLionW', 'DecoupledClipLion', 'DecoupledAdaLRLion']
+__all__ = [
+    'DecoupledLionW', 'DecoupledLionW_8bit', 'DecoupledClipLion',
+    'DecoupledAdaLRLion'
+]
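
For reference, the new export can be constructed directly like any torch optimizer. A minimal sketch follows; the keyword names follow the usual LION hyperparameters and are assumptions rather than the verified constructor signature, and the values are placeholders, not recommendations.

import torch

from llmfoundry.optim import DecoupledLionW_8bit

model = torch.nn.Linear(16, 16)
# Hyperparameter names and values here are illustrative assumptions.
opt = DecoupledLionW_8bit(model.parameters(), lr=1e-4, betas=(0.9, 0.99),
                          weight_decay=1e-5)

On GPU, the quantized momentum state is what saves memory relative to DecoupledLionW.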
429 changes: 429 additions & 0 deletions llmfoundry/optim/lion8b.py

Large diffs are not rendered by default.
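
The lion8b.py diff is collapsed above. As a rough illustration only, and not the DecoupledLionW_8bit implementation (the real optimizer uses fused CUDA kernels from mosaicml-turbo and a more careful quantization scheme), here is a minimal sketch of the underlying idea: a decoupled LION step whose momentum state is stored as int8 plus a per-tensor scale.

import torch


def lion8b_step_sketch(param: torch.Tensor, grad: torch.Tensor,
                       mom_q: torch.Tensor, mom_scale: torch.Tensor,
                       lr: float = 1e-4, beta1: float = 0.9,
                       beta2: float = 0.99, weight_decay: float = 0.0):
    """One LION step with momentum kept as int8 plus a per-tensor scale."""
    mom = mom_q.float() * mom_scale                         # dequantize state
    update = torch.sign(beta1 * mom + (1 - beta1) * grad)   # LION: sign of interpolated momentum
    param.mul_(1 - lr * weight_decay)                       # decoupled weight decay
    param.add_(update, alpha=-lr)
    mom = beta2 * mom + (1 - beta2) * grad                  # momentum update
    new_scale = mom.abs().max().clamp(min=1e-12) / 127.0    # symmetric 8-bit requantization
    new_mom_q = torch.clamp(torch.round(mom / new_scale), -127, 127).to(torch.int8)
    return new_mom_q, new_scale


# Toy usage: the int8 state is ~4x smaller than an fp32 momentum buffer.
p, g = torch.randn(1024), torch.randn(1024)
mom_q = torch.zeros(1024, dtype=torch.int8)
scale = torch.tensor(0.0)
mom_q, scale = lion8b_step_sketch(p, g, mom_q, scale)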

4 changes: 3 additions & 1 deletion llmfoundry/utils/builders.py
@@ -26,7 +26,7 @@
                                   LayerFreezing, MonolithicCheckpointSaver,
                                   ScheduledGarbageCollector)
 from llmfoundry.optim import (DecoupledAdaLRLion, DecoupledClipLion,
-                              DecoupledLionW)
+                              DecoupledLionW, DecoupledLionW_8bit)


 def build_callback(name: str, kwargs: Dict[str, Any]):
@@ -98,6 +98,8 @@ def build_optimizer(model: torch.nn.Module, name: str,
         return DecoupledClipLion(model.parameters(), **optimizer_config)
     elif name == 'adalr_lion':
         return DecoupledAdaLRLion(model.parameters(), **optimizer_config)
+    elif name == 'decoupled_lionw_8b':
+        return DecoupledLionW_8bit(model.parameters(), **optimizer_config)
     else:
         raise ValueError(f'Not sure how to build optimizer: {name}')
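
With the new branch in build_optimizer, the optimizer can be selected by name. A minimal sketch, assuming the third positional argument is the optimizer_config dict as the hunk header suggests, with placeholder hyperparameters:

import torch

from llmfoundry.utils.builders import build_optimizer

model = torch.nn.Linear(8, 8)
optimizer_config = {'lr': 1e-4, 'weight_decay': 1e-5}  # placeholders only
opt = build_optimizer(model, 'decoupled_lionw_8b', optimizer_config)

In a run YAML this would correspond to setting the optimizer name to decoupled_lionw_8b, with the remaining keys passed through as keyword arguments.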
21 changes: 20 additions & 1 deletion llmfoundry/utils/config_utils.py
@@ -4,7 +4,7 @@
 import contextlib
 import math
 import warnings
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Mapping, Optional, Union

 from composer.utils import dist
 from omegaconf import DictConfig, ListConfig
@@ -116,6 +116,25 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]):
         # Set defaults for mixed initialization
         fsdp_config.setdefault('use_orig_params', False)
         fsdp_config.setdefault('load_monolith_rank0_only', True)

+    # no mixed precision needed for weights when they're already 16 bits
+    master_dtype = model_cfg.get('master_weights_dtype')
+    small_dtypes = ('bf16', 'f16', 'float16', 'bfloat16', 'amp_fp16',
+                    'amp_bf16')
+    if fsdp_config and master_dtype in small_dtypes:
+        reduce_dtype = None
+        buffer_dtype = None
+        mixed_precision = fsdp_config.get('mixed_precision')
+        if isinstance(mixed_precision, Mapping):
+            reduce_dtype = mixed_precision.get('reduce_dtype')
+            buffer_dtype = mixed_precision.get('buffer_dtype')
+        fsdp_config['mixed_precision'] = {
+            'param_dtype': None,
+            'reduce_dtype': reduce_dtype,
+            'buffer_dtype': buffer_dtype,
+            'keep_low_precision_grads': True,
+        }
+
     return init_context
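
To make the effect of the new block concrete, here is a before/after view of a hypothetical fsdp_config when master_weights_dtype is 'bf16' (the surrounding keys and dtype strings are made up for the example):

fsdp_config_before = {
    'sharding_strategy': 'FULL_SHARD',
    'mixed_precision': {'param_dtype': 'bf16', 'reduce_dtype': 'fp32'},
}

# After process_init_device with master_weights_dtype='bf16':
fsdp_config_after = {
    'sharding_strategy': 'FULL_SHARD',
    'mixed_precision': {
        'param_dtype': None,          # weights are already bf16, so no re-cast
        'reduce_dtype': 'fp32',       # carried over from the existing config
        'buffer_dtype': None,         # was not set before, so stays None
        'keep_low_precision_grads': True,
    },
}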
5 changes: 5 additions & 0 deletions scripts/train/train.py
@@ -400,6 +400,10 @@ def main(cfg: DictConfig):
         print_trainable_parameters(model)  # should not be 100%
     else:  # standard model
         model = build_composer_model(model_config, tokenizer)
+    if model_config.get('master_weights_dtype') in ('bf16', 'bfloat16'):
+        model = model.to(dtype=torch.bfloat16)
+    elif model_config.get('master_weights_dtype') in ('f16', 'float16'):
+        model = model.to(dtype=torch.float16)

     # Log number of parameters
     n_params = sum(p.numel() for p in model.parameters())
@@ -515,5 +519,6 @@ def main(cfg: DictConfig):
         yaml_cfg = om.load(f)
     cli_cfg = om.from_cli(args_list)
     cfg = om.merge(yaml_cfg, cli_cfg)
+    om.resolve(cfg)
     assert isinstance(cfg, DictConfig)
     main(cfg)
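
Besides the master-weights cast, the other train.py addition is om.resolve(cfg), which materializes OmegaConf interpolations in place before the config reaches main. A minimal sketch with a made-up config:

from omegaconf import OmegaConf as om

cfg = om.create({
    'max_seq_len': 2048,
    'model': {'max_seq_len': '${max_seq_len}'},  # interpolation
})
om.resolve(cfg)                       # replaces the interpolation in place
assert cfg.model.max_seq_len == 2048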
1 change: 1 addition & 0 deletions setup.py
@@ -84,6 +84,7 @@

 extra_deps['gpu'] = [
     'flash-attn==v1.0.3.post0',
+    'mosaicml-turbo>=0.0.2,<0.1',
     # PyPI does not support direct dependencies, so we remove this line before uploading from PyPI
     'xentropy-cuda-lib@git+https://github.com/HazyResearch/[email protected]#subdirectory=csrc/xentropy',
 ]