
Commit

Merge branch 'main' into hf-flash-attn-fix
irenedea committed Oct 9, 2023
2 parents ad6fd55 + aa2ba9f commit c3ad4b4
Showing 12 changed files with 84 additions and 20 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/docker.yaml
@@ -15,6 +15,8 @@ jobs:
base_image: mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04
- name: '2.0.1_cu118'
base_image: mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04
- name: '2.1.0_cu121'
base_image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04

steps:
- name: Maximize Build Space on Worker
4 changes: 4 additions & 0 deletions .github/workflows/pr-cpu.yaml
@@ -27,6 +27,10 @@ jobs:
container: mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04
markers: 'not gpu'
pytest_command: 'coverage run -m pytest'
- name: 'cpu-2.1.0'
container: mosaicml/pytorch:2.1.0_cpu-python3.10-ubuntu20.04
markers: 'not gpu'
pytest_command: 'coverage run -m pytest'
name: ${{ matrix.name }}
if: github.repository_owner == 'mosaicml'
with:
6 changes: 5 additions & 1 deletion .github/workflows/pr-gpu.yaml
@@ -24,7 +24,11 @@ jobs:
markers: 'gpu'
pytest_command: 'coverage run -m pytest'
- name: 'gpu-2.0.1'
container: mosaicml/pytorch:2.0.1_cu117-python3.10-ubuntu20.04
container: mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04
markers: 'gpu'
pytest_command: 'coverage run -m pytest'
- name: 'gpu-2.1.0'
container: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
markers: 'gpu'
pytest_command: 'coverage run -m pytest'
name: ${{ matrix.name }}
13 changes: 9 additions & 4 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -63,10 +63,7 @@ def __init__(self, om_model_config: Union[DictConfig,
nn.Module],
tokenizer: PreTrainedTokenizerBase):
# set up training and eval metrics
train_metrics = [
LanguageCrossEntropy(),
LanguagePerplexity(),
]
train_metrics = [LanguageCrossEntropy(), LanguagePerplexity()]
eval_metrics = [
LanguageCrossEntropy(),
LanguagePerplexity(),
@@ -90,6 +87,9 @@ def __init__(self, om_model_config: Union[DictConfig,
'which is not significantly slower and not compatible with the LLM foundry training code, rather than the code release by MosaicML.'
)

if not om_model_config.get('use_train_metrics', True):
train_metrics = []

# load the model config
trust_remote_code = om_model_config.get('trust_remote_code', True)
use_auth_token = om_model_config.get('use_auth_token', False)
@@ -107,6 +107,7 @@ def __init__(self, om_model_config: Union[DictConfig,
)

attr = getattr(config, k)
# attempt to disallow typos in nested configs
if isinstance(attr, Mapping):
extra_keys = [
_k for _k in v.keys() if _k not in attr.keys()
@@ -118,6 +119,10 @@ def __init__(self, om_model_config: Union[DictConfig,
f'Expected (a subset of) keys: {list(attr.keys())}.'
)
getattr(config, k).update(v)
# necessary case to allow for rope_scaling to be overridden in llama config
elif attr is None and isinstance(v, Mapping):
setattr(config, k, {})
getattr(config, k).update(v)
else:
setattr(config, k, v)

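
The new elif branch above is needed because some Hugging Face configs (e.g. LlamaConfig) initialize rope_scaling to None, so the existing Mapping branch never fires for them. A minimal, self-contained sketch of the merge behavior, using SimpleNamespace as a stand-in for the real PretrainedConfig object (all values here are illustrative, not taken from the diff):

# Sketch of the config_overrides merge logic: nested dict overrides are
# merged into existing mapping attributes, and attributes that default
# to None are first initialized to an empty dict before the update.
from types import SimpleNamespace

config = SimpleNamespace(num_hidden_layers=32, rope_scaling=None)
config_overrides = {
    'num_hidden_layers': 2,
    'rope_scaling': {'type': 'dynamic', 'factor': 0.5},
}

for k, v in config_overrides.items():
    attr = getattr(config, k)
    if isinstance(attr, dict):
        attr.update(v)                  # merge into an existing nested mapping
    elif attr is None and isinstance(v, dict):
        setattr(config, k, {})          # attribute unset on the base config
        getattr(config, k).update(v)
    else:
        setattr(config, k, v)           # plain value override

assert config.rope_scaling == {'type': 'dynamic', 'factor': 0.5}
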
4 changes: 3 additions & 1 deletion llmfoundry/models/mpt/modeling_mpt.py
@@ -694,7 +694,9 @@ def __init__(
hf_config = MPTConfig.from_dict(resolved_om_model_config)
model = MPTForCausalLM(hf_config)

train_metrics = [LanguageCrossEntropy(), LanguagePerplexity()]
use_train_metrics = om_model_config.get('use_train_metrics', True)
train_metrics = [LanguageCrossEntropy(),
LanguagePerplexity()] if use_train_metrics else []
eval_metrics = [
LanguageCrossEntropy(),
LanguagePerplexity(),
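
Both this change and the hf_causal_lm.py change above read the same flag, so training metrics can now be disabled from the model section of a run config. A hedged sketch of such a config, shown as a plain Python dict; only 'use_train_metrics' comes from this diff, the remaining keys are illustrative MPT hyperparameters:

# Sketch only: a model config that opts out of train metrics.
model_cfg = {
    'name': 'mpt_causal_lm',
    'd_model': 128,
    'n_heads': 2,
    'n_layers': 2,
    'max_seq_len': 256,
    'vocab_size': 50368,
    'use_train_metrics': False,  # skip LanguageCrossEntropy/LanguagePerplexity during training
}
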
27 changes: 21 additions & 6 deletions llmfoundry/optim/lion8b.py
@@ -53,7 +53,7 @@ class DecoupledLionW_8bit(torch.optim.Optimizer):
by retaining information across optimizer steps.
Raises:
NotImplemenetedError - If any of `quantize`, `compress_state_dict`,
NotImplementedError - If any of `quantize`, `compress_state_dict`,
or `error_correction` are `True` and either a) there is no CUDA
device, or b) step() is executed on a non-CUDA parameter.
"""
@@ -67,6 +67,7 @@ def __init__(self,
compress_state_dict: bool = False,
error_correction: bool = False,
_fused: bool = True): # XXX this flag is mostly for testing...

if lr < 0.0:
raise ValueError('Invalid learning rate: {}'.format(lr))
if not 0.0 <= betas[0] <= 1.0:
@@ -131,11 +132,19 @@ def step_param(self, p: torch.Tensor, hparams: Dict[str, Any]) -> None:
mom, try_quantize=self._quantize)
need_errs = (p.dtype != torch.float32) and self._error_correction
if state.get('errors') is None and need_errs:
state['errors'] = torch.zeros(p.shape,
dtype=torch.uint8,
device=p.device)
numel = p.numel()
numel += numel % 2 # ensure even number of bytes
errors = torch.zeros(numel, dtype=torch.uint8, device=p.device)
# as of torch 2.1, FSDP can't shard ints for no reason
state['errors'] = errors.view(torch.bfloat16)
decay_factor = hparams['weight_decay']
decay_factor *= hparams['lr'] / hparams['initial_lr']
errors: Optional[torch.Tensor] = None
if 'errors' in state:
errors = state['errors']
assert errors is not None # pyright
errors = errors.view(dtype=torch.uint8)
errors = errors[:p.numel()].view(p.shape) # strip padding + reshape
_lion8b_step(momentums=state['exp_avg'],
weights=p,
grads=p.grad,
Expand All @@ -144,7 +153,7 @@ def step_param(self, p: torch.Tensor, hparams: Dict[str, Any]) -> None:
lr=hparams['lr'],
weight_decay=decay_factor,
fused=hparams['fused'],
errors=state.get('errors'))
errors=errors)

def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None:
# we override this function to quantize optimizer states when
@@ -166,7 +175,8 @@ def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None:
# we need to cast back to the correct dtype since optimizer
# load_state_dict casts to param dtype for fp params; see
# https://github.com/pytorch/pytorch/blob/a25eee1d77d93079614fab3ea4ac66e64fb2343b/torch/optim/optimizer.py#L626C7-L626C7 # noqa
errs = param_state['errors'].to(dtype=torch.uint8)
errs = param_state['errors'].to(dtype=torch.uint8).view(
torch.bfloat16)
new_state['errors'] = errs
opt_state[param_id] = new_state
super().__setstate__(state)
@@ -192,6 +202,11 @@ def state_dict(self):
qtensor.state_dict(
name='exp_avg',
allow_quantized=self._compress_state_dict))
if 'errors' in param_state:
# fsdp apparently needs the states to be the same shape
# as the params
param_state['errors'] = param_state['errors'].view(
torch.uint8).to(dtype=torch.bfloat16)
opt_state[param_id] = param_state
return d

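
The error-correction changes above work around FSDP's inability (as of torch 2.1) to shard integer tensors: the uint8 error state is padded to an even byte count and stored viewed as bfloat16, then viewed back to uint8 and un-padded inside step_param. A standalone sketch of that round trip (not the optimizer code; shapes and values are illustrative):

# Sketch of the padding + dtype-view round trip used for the 'errors' state.
import torch

p = torch.randn(3, 5, dtype=torch.bfloat16)         # a parameter with an odd numel
numel = p.numel()
numel += numel % 2                                   # pad to an even number of bytes
errors = torch.zeros(numel, dtype=torch.uint8)
stored = errors.view(torch.bfloat16)                 # float view that FSDP can shard

# ... later, inside step_param ...
restored = stored.view(dtype=torch.uint8)            # back to raw bytes
restored = restored[:p.numel()].view(p.shape)        # strip padding, reshape to p
assert restored.shape == p.shape
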
1 change: 1 addition & 0 deletions mcli/mcli-llama2-finetune.yaml
@@ -2,6 +2,7 @@ integrations:
- integration_type: git_repo
git_repo: mosaicml/llm-foundry
git_branch: v0.3.0
# git_commit: # OR use your commit hash
pip_install: -e .[gpu]
ssh_clone: false # Should be true if using a private repo

4 changes: 2 additions & 2 deletions mcli/mcli-openai-eval.yaml
@@ -1,8 +1,8 @@
integrations:
- integration_type: git_repo
git_repo: mosaicml/llm-foundry
git_branch: # use your branch
# git_commit: 29d65cc26853c09f6de7542978056ddb0b07e98c # OR use your commit hash
git_branch: v0.3.0
# git_commit: # OR use your commit hash
pip_install: -e ".[gpu,openai]"
ssh_clone: false # Should be true if using a private repo

2 changes: 1 addition & 1 deletion mcli/mcli-pretokenize-oci-upload.yaml
@@ -14,7 +14,7 @@ integrations:
- oci-cli==3.23.2
- integration_type: git_repo
git_repo: mosaicml/llm-foundry
git_branch: v0.2.0
git_branch: v0.3.0
# git_commit: # OR use your commit hash
pip_install: '.'
ssh_clone: false # Should be true if using a private repo
5 changes: 4 additions & 1 deletion scripts/train/train.py
@@ -392,9 +392,12 @@ def main(cfg: DictConfig) -> Trainer:
and save_folder is not None \
and not save_overwrite \
and not save_weights_only:
autoresume_default = True

if cfg.get('autoresume') is None and autoresume_default:
print('As run_name, save_folder, and save_latest_filename are set, \
changing autoresume default to True...')
autoresume_default = True

autoresume: bool = pop_config(cfg,
'autoresume',
must_exist=False,
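
After this reordering, the message is only printed when autoresume would actually default to True and the user has not set it explicitly. A simplified sketch of the resulting flow, assuming a plain dict config and function form purely for illustration (the real train.py pops these values off an omegaconf run config):

# Simplified sketch of the autoresume default logic after this change.
from typing import Any, Dict, Optional


def resolve_autoresume(cfg: Dict[str, Any],
                       run_name: Optional[str],
                       save_folder: Optional[str],
                       save_overwrite: bool,
                       save_weights_only: bool) -> bool:
    autoresume_default = False
    if run_name is not None and save_folder is not None \
            and not save_overwrite and not save_weights_only:
        autoresume_default = True

    # Only announce the default when the user did not set autoresume explicitly.
    if cfg.get('autoresume') is None and autoresume_default:
        print('As run_name, save_folder, and save_latest_filename are set, '
              'changing autoresume default to True...')

    value = cfg.get('autoresume')
    return autoresume_default if value is None else value
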
28 changes: 28 additions & 0 deletions tests/test_hf_config.py
@@ -1,6 +1,7 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import os
import tempfile
from copy import deepcopy
from pathlib import Path
@@ -139,3 +140,30 @@ def test_hf_config_override(
assert getattr(hf_model.config, k)[_k] == _v
else:
assert getattr(hf_model.config, k) == v


@pytest.mark.skipif('HUGGING_FACE_HUB_TOKEN' not in os.environ,
reason='CI does not have access to llama2')
def test_rope_scaling_override():
model_cfg = {
'name': 'hf_causal_lm',
'pretrained_model_name_or_path': 'meta-llama/Llama-2-7b-hf',
'config_overrides': {
'num_hidden_layers': 2,
'hidden_size': 32,
'intermediate_size': 64,
'rope_scaling': {
'type': 'dynamic',
'factor': 0.5
}
},
'use_auth_token': True,
'pretrained': False,
'init_device': 'cpu',
}
model_cfg = om.create(model_cfg)

model = COMPOSER_MODEL_REGISTRY[model_cfg.name](model_cfg, tokenizer=None)
# This would error if the config isn't parsed into a proper dictionary
model.get_metadata()
assert model.config.rope_scaling == {'type': 'dynamic', 'factor': 0.5}
8 changes: 4 additions & 4 deletions tests/test_lion8b.py
@@ -387,7 +387,7 @@ class _DummyModule(nn.Module):
def __init__(self, device: str, dtype: torch.dtype):
super().__init__()
self.linear0 = nn.Linear(4, 3, device=device, dtype=dtype)
self.linear1 = nn.Linear(3, 4, device=device, dtype=dtype)
self.linear1 = nn.Linear(3, 5, device=device, dtype=dtype)

def forward(self, x: torch.Tensor) -> torch.Tensor: # type:ignore
return self.linear1(self.linear0(x))
@@ -416,7 +416,7 @@ def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool,

torch.cuda.set_device(f'cuda:{os.environ["RANK"]}') # needed for fsdp
if not dist.is_initialized():
dist.init_process_group()
dist.init_process_group(backend='nccl')
assert dist.get_world_size() >= 2, 'Misconfigured test run!'

mod = FSDP(_DummyModule(device=device, dtype=dtype))
@@ -460,7 +460,7 @@ def _set_state_dict_type(model: nn.Module):

# load state dict into the new optimizer
opt_state_dict_slice = FSDP.optim_state_dict_to_load(
opt_state_dict, mod_new, opt_new)
optim_state_dict=opt_state_dict, model=mod_new, optim=opt_new)
opt_new.load_state_dict(opt_state_dict_slice)

new_opt_state_dict = FSDP.optim_state_dict(mod_new, opt_new)
@@ -481,7 +481,7 @@ def _set_state_dict_type(model: nn.Module):

assert mom_orig.shape == mom_new.shape
assert mom_orig.dtype == mom_new.dtype
if use_errors:
if use_errors and (dtype != torch.float32):
errs_orig = d_orig['errors']
errs_new = d_new['errors']
assert errs_orig.shape == errs_new.shape
