
enable param group configuration in llm-foundry
vchiley committed Nov 22, 2023
1 parent f8ee914 commit 26bbc57
Showing 2 changed files with 59 additions and 16 deletions.
22 changes: 12 additions & 10 deletions llmfoundry/optim/lion8b.py
@@ -1,7 +1,7 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0

-from typing import Any, Callable, Dict, Iterable, Optional, Tuple
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union

 import torch

@@ -58,15 +58,17 @@ class DecoupledLionW_8bit(torch.optim.Optimizer):
             device, or b) step() is executed on a non-CUDA parameter.
     """

-    def __init__(self,
-                 params: Iterable[torch.Tensor],
-                 lr: float = 1e-3,
-                 betas: Tuple[float, float] = (0.9, 0.99),
-                 weight_decay: float = 0,
-                 quantize: bool = True,
-                 compress_state_dict: bool = False,
-                 error_correction: bool = False,
-                 _fused: bool = True):  # XXX this flag is mostly for testing...
+    def __init__(
+        self,
+        params: Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]],
+        lr: float = 1e-3,
+        betas: Tuple[float, float] = (0.9, 0.99),
+        weight_decay: float = 0,
+        quantize: bool = True,
+        compress_state_dict: bool = False,
+        error_correction: bool = False,
+        _fused: bool = True,  # XXX this flag is mostly for testing...
+    ):

         if lr < 0.0:
             raise ValueError('Invalid learning rate: {}'.format(lr))
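
For orientation, the widened params annotation matches the usual torch.optim convention: the optimizer accepts either a flat iterable of tensors or a list of param-group dicts whose group-level keys override the top-level defaults. Below is a minimal sketch of both call styles, using a toy torch.nn.Linear model that is not part of the commit; quantize=False is passed only so the sketch also runs without a CUDA device.

import torch

from llmfoundry.optim.lion8b import DecoupledLionW_8bit

model = torch.nn.Linear(8, 8)

# Flat iterable of tensors, the only form the old annotation advertised.
opt_flat = DecoupledLionW_8bit(model.parameters(), lr=1e-3, quantize=False)

# List of param-group dicts, the form the new Union[...] annotation makes
# explicit; group-level keys override the top-level defaults, as in any
# torch.optim.Optimizer subclass.
opt_grouped = DecoupledLionW_8bit(
    [
        {'params': [model.weight], 'lr': 1e-3},
        {'params': [model.bias], 'lr': 1e-4, 'weight_decay': 0.0},
    ],
    lr=1e-3,
    weight_decay=1e-5,
    quantize=False,
)

Nothing else changes in lion8b.py; the logic that actually builds such param groups from a config lives in builders.py below.
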
53 changes: 47 additions & 6 deletions llmfoundry/utils/builders.py
@@ -4,7 +4,8 @@
 import logging
 import os
 import warnings
-from typing import Any, Dict, List, Optional, Tuple, Union
+from collections import OrderedDict
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

 import torch
 from composer import algorithms
@@ -155,18 +156,58 @@ def build_algorithm(name: str, kwargs: Dict[str, Any]) -> Algorithm:
         raise ValueError(f'Not sure how to build algorithm: {name}')


+def extract_param_groups(
+    model: torch.nn.Module,
+    optimizer_config: Dict[str, Any],
+) -> Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]:
+
+    if 'disable_grad' in optimizer_config.keys():
+        str_match = optimizer_config.pop('disable_grad')
+        if isinstance(str_match, str):
+            str_match = [str_match]
+        for _str_match in str_match:
+            for n, p in model.named_parameters():
+                if n in _str_match:
+                    p.requires_grad = False
+
+    if 'param_groups' in optimizer_config.keys():
+        params = []
+        param_dict = OrderedDict((n, p) for n, p in model.named_parameters())
+
+        for param_group_config in optimizer_config['param_groups']:
+            str_match = param_group_config.pop('param_str_match')
+            group_param_names = [n for n in param_dict.keys() if str_match in n]
+            _params = []
+            for n in group_param_names:
+                _params.append(param_dict.pop(n))
+            group_params = {'params': _params}
+            group_params.update(param_group_config)
+
+            params.append(group_params)
+
+        optimizer_config.pop('param_groups')
+
+        params.insert(0, {'params': param_dict.values()})
+        return params
+
+    return model.parameters()
+
+
 def build_optimizer(model: torch.nn.Module, name: str,
                     optimizer_config: Dict[str, Any]) -> Optimizer:
+
+    params = extract_param_groups(model, optimizer_config)
+
     if name == 'decoupled_adamw':
-        return DecoupledAdamW(model.parameters(), **optimizer_config)
+        return DecoupledAdamW(params, **optimizer_config)
     elif name == 'decoupled_lionw':
-        return DecoupledLionW(model.parameters(), **optimizer_config)
+        return DecoupledLionW(params, **optimizer_config)
     elif name == 'clip_lion':
-        return DecoupledClipLion(model.parameters(), **optimizer_config)
+        return DecoupledClipLion(params, **optimizer_config)
     elif name == 'adalr_lion':
-        return DecoupledAdaLRLion(model.parameters(), **optimizer_config)
+        return DecoupledAdaLRLion(params, **optimizer_config)
     elif name == 'decoupled_lionw_8b':
-        return DecoupledLionW_8bit(model.parameters(), **optimizer_config)
+        return DecoupledLionW_8bit(params, **optimizer_config)
     else:
         raise ValueError(f'Not sure how to build optimizer: {name}')

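
To see how build_optimizer consumes the new keys, here is a hypothetical optimizer_config; the toy model and the match strings are invented for the example and are not part of the commit. disable_grad turns off requires_grad for the parameters it names, each param_groups entry collects the parameters whose names contain its param_str_match into a group with its own overrides, and every remaining parameter lands in the default (first) group.

from collections import OrderedDict

import torch

from llmfoundry.utils.builders import build_optimizer

# Toy stand-in model; only the parameter names ('embed.weight',
# 'decoder.weight', 'decoder.bias') matter for the matching below.
model = torch.nn.Sequential(
    OrderedDict([
        ('embed', torch.nn.Embedding(128, 16)),
        ('decoder', torch.nn.Linear(16, 16)),
    ]))

optimizer_config = {
    'lr': 1e-3,
    'betas': (0.9, 0.95),
    'weight_decay': 1e-5,
    # Freeze the embedding weight (listed by its full parameter name).
    'disable_grad': ['embed.weight'],
    # Give the decoder parameters their own learning rate; the rest fall
    # into the default group driven by the top-level settings above.
    'param_groups': [
        {'param_str_match': 'decoder', 'lr': 1e-4},
    ],
}

optimizer = build_optimizer(model, 'decoupled_adamw', optimizer_config)

Note that extract_param_groups pops disable_grad, param_groups, and each param_str_match out of the config, so by the time the optimizer constructor runs, optimizer_config holds only plain optimizer kwargs.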
