Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable param group configuration in llm-foundry #760

Merged
merged 11 commits into from
Nov 29, 2023
22 changes: 12 additions & 10 deletions llmfoundry/optim/lion8b.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Dict, Iterable, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union

import torch

Expand Down Expand Up @@ -58,15 +58,17 @@ class DecoupledLionW_8bit(torch.optim.Optimizer):
device, or b) step() is executed on a non-CUDA parameter.
"""

def __init__(self,
params: Iterable[torch.Tensor],
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.99),
weight_decay: float = 0,
quantize: bool = True,
compress_state_dict: bool = False,
error_correction: bool = False,
_fused: bool = True): # XXX this flag is mostly for testing...
def __init__(
self,
params: Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]],
vchiley marked this conversation as resolved.
Show resolved Hide resolved
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.99),
weight_decay: float = 0,
quantize: bool = True,
compress_state_dict: bool = False,
error_correction: bool = False,
_fused: bool = True, # XXX this flag is mostly for testing...
):

if lr < 0.0:
raise ValueError('Invalid learning rate: {}'.format(lr))
Expand Down
115 changes: 109 additions & 6 deletions llmfoundry/utils/builders.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import functools
import logging
import os
import re
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union
from collections import OrderedDict
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from composer import algorithms
Expand Down Expand Up @@ -155,18 +158,118 @@ def build_algorithm(name: str, kwargs: Dict[str, Any]) -> Algorithm:
raise ValueError(f'Not sure how to build algorithm: {name}')


def _extract_param_groups(
model: torch.nn.Module,
optimizer_config: Dict[str, Any],
) -> Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]:
"""Extracts parameter groups defined in the optimizer config.

The optimizer_config defines the optimizer args. It can additionally have key
`disable_grad` which is a string or list of strings. If a string matches a
parameter name, then that parameter will have `requires_grad=False`. This is
useful for freezing parameters. It can additionally have a key
`param_groups` which is a list of dicts. In this dict, key `param_str_match`
defines a string; if a parameter name contains this string, then it will be
in this parameter group. This is useful for grouping parameters together.
The dict can also contain any other key that is a valid optimizer arg.
Note: to handle name overlap conflicts, params are assigned to parameter
groups and added to `param_groups` in the order that `param_str_match` appear
in `param_groups`.

Usage
To disable gradient for all parameters that contain the string "norm" or "bias":
```
optimizer_config: {
"name": "decoupled_lionw",
"lr": 1e-3,
"weight_decay": 1e-2,
"betas": [0.9, 0.999],
"eps": 1e-8,
"disable_grad": ["norm", "bias"]
}
```

To create modify the optimizer parameters for all parameters that contain the
vchiley marked this conversation as resolved.
Show resolved Hide resolved
string "norm" and "bias" separately:
```
optimizer_config: {
"name": "decoupled_lionw",
"lr": 1e-3,
"weight_decay": 1e-2,
"betas": [0.9, 0.999],
"eps": 1e-8,
"param_groups": [
{
"param_str_match": "norm",
"lr": 1e-4,
"weight_decay": 0.0,
},
{
"param_str_match": "bias",
"lr": 5e-4,
"weight_decay": 0.0,
},
],
}
```

Args:
model (torch.nn.Module): model to extract parameters from
optimizer_config (Dict[str, Any]): optimizer config

Returns:
Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]: an iterable of
torch.Tensor's or dict's. Specifies what Tensors should be optimized.
vchiley marked this conversation as resolved.
Show resolved Hide resolved
"""
if 'disable_grad' in optimizer_config.keys():
str_matches = optimizer_config.pop('disable_grad')
if isinstance(str_matches, str):
str_matches = [str_matches]
for str_match in str_matches:
for n, p in model.named_parameters():
if re.search(str_match, n):
p.requires_grad = False
log.debug(f'Setting `{n}.requires_grad = False`.')

if 'param_groups' in optimizer_config.keys():
params = []
param_dict = OrderedDict((n, p) for n, p in model.named_parameters())

for param_group_config in optimizer_config['param_groups']:
str_match = param_group_config.pop('param_str_match')
filter_fn = functools.partial(re.search, str_match)
param_names = [n for n in param_dict.keys() if filter_fn(n)]
group_params = {'params': [param_dict.pop(n) for n in param_names]}
j316chuck marked this conversation as resolved.
Show resolved Hide resolved
group_params.update(param_group_config)

params.append(group_params)

optimizer_config.pop('param_groups')

params.insert(0, {'params': param_dict.values()})

log.debug(f'Optimizer param_groups: {params}.')
vchiley marked this conversation as resolved.
Show resolved Hide resolved

return params

return model.parameters()


def build_optimizer(model: torch.nn.Module, name: str,
optimizer_config: Dict[str, Any]) -> Optimizer:

params = _extract_param_groups(model, optimizer_config)

if name == 'decoupled_adamw':
return DecoupledAdamW(model.parameters(), **optimizer_config)
return DecoupledAdamW(params, **optimizer_config)
elif name == 'decoupled_lionw':
return DecoupledLionW(model.parameters(), **optimizer_config)
return DecoupledLionW(params, **optimizer_config)
elif name == 'clip_lion':
return DecoupledClipLion(model.parameters(), **optimizer_config)
return DecoupledClipLion(params, **optimizer_config)
elif name == 'adalr_lion':
return DecoupledAdaLRLion(model.parameters(), **optimizer_config)
return DecoupledAdaLRLion(params, **optimizer_config)
elif name == 'decoupled_lionw_8b':
return DecoupledLionW_8bit(model.parameters(), **optimizer_config)
return DecoupledLionW_8bit(params, **optimizer_config)
else:
raise ValueError(f'Not sure how to build optimizer: {name}')

Expand Down
82 changes: 80 additions & 2 deletions tests/test_builders.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import re
import unittest.mock as mock
from typing import Union
from copy import deepcopy
from typing import Any, Dict, Union

import pytest
import torch
import torch.nn as nn
from composer.callbacks import Generate
from omegaconf import OmegaConf as om
from transformers import PreTrainedTokenizerBase

from llmfoundry.callbacks import HuggingFaceCheckpointer
from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper
from llmfoundry.utils.builders import build_callback, build_tokenizer
from llmfoundry.utils.builders import (build_callback, build_optimizer,
build_tokenizer)


@pytest.mark.parametrize('tokenizer_name,tokenizer_kwargs', [
Expand Down Expand Up @@ -110,3 +115,76 @@ def test_build_hf_checkpointer_callback():
assert isinstance(kwargs['mlflow_logging_config'], dict)
assert isinstance(kwargs['mlflow_logging_config']['metadata'], dict)
assert kwargs['mlflow_logging_config'] == mlflow_logging_config_dict


class _DummyModule(nn.Module):

def __init__(self, device: str = 'cpu', dtype: torch.dtype = torch.float32):
super().__init__()
self.linear0 = nn.Linear(4, 3, device=device, dtype=dtype)
self.norm0 = nn.LayerNorm(3, device=device, dtype=dtype)
self.linear1 = nn.Linear(3, 5, device=device, dtype=dtype)

def forward(self, x: torch.Tensor) -> torch.Tensor: # type:ignore
return self.linear1(self.norm0(self.linear0(x)))


@pytest.mark.parametrize('name, optimizer_config', [
('decoupled_adamw', {}),
('decoupled_lionw', {}),
('clip_lion', {}),
('adalr_lion', {}),
pytest.param('decoupled_lionw_8b', {}, marks=pytest.mark.gpu),
])
@pytest.mark.parametrize('opt_additional_config', [
vchiley marked this conversation as resolved.
Show resolved Hide resolved
{
'disable_grad': 'norm'
},
{
'disable_grad': ['norm', 'bias']
},
{
'param_groups': [{
'param_str_match': 'norm',
'lr': 1e-9,
'weight_decay': 0.0,
},]
},
{
'param_groups': [{
'param_str_match': 'norm',
'lr': 1e-4,
'weight_decay': 0.0,
},],
'disable_grad': ['bias'],
},
])
def test_build_optimizer(name: str, optimizer_config: Dict[str, Any],
opt_additional_config: Dict[str, Any]):
model = _DummyModule()
optimizer_config.update(deepcopy(opt_additional_config))
optimizer = build_optimizer(model, name, optimizer_config)

if 'disable_grad' in opt_additional_config.keys():
disable_grad = opt_additional_config['disable_grad']
if isinstance(disable_grad, str):
disable_grad = [disable_grad]
for n, p in model.named_parameters():
for k in disable_grad:
if re.search(k, n):
assert not p.requires_grad

if 'param_groups' in opt_additional_config.keys():
for param_group_config, param_group in zip(
opt_additional_config['param_groups'],
optimizer.param_groups[1:]):
param_group_config = deepcopy(param_group_config)
param_str_match = param_group_config.pop('param_str_match')

for k, v in param_group_config.items():
assert param_group[k] == v

param_ids = [id(p) for p in param_group['params']]
for n, p in model.named_parameters():
if re.search(param_str_match, n):
assert id(p) in param_ids
Loading