add doc string
vchiley committed Nov 23, 2023
1 parent 26bbc57 commit f39592a
Showing 1 changed file with 56 additions and 1 deletion.
57 changes: 56 additions & 1 deletion llmfoundry/utils/builders.py
@@ -160,7 +160,62 @@ def extract_param_groups(
model: torch.nn.Module,
optimizer_config: Dict[str, Any],
) -> Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]:

"""Extracts parameter groups defined in the optimizer config.
The optimizer_config defines the optimizer args. It can additionally have key
`disable_grad` which is a string or list of strings. If a string matches a
parameter name, then that parameter will have `requires_grad=False`. This is
useful for freezing parameters. It can additionally have a key
`param_groups` which is a list of dicts. In this dict, key `param_str_match`
defines a string; if a parameter name contains this string, then it will be
in this parameter group. This is useful for grouping parameters together.
The dict can also contain any other key that is a valid optimizer arg.
Usage
To disable gradients for all parameters whose names contain the string "norm" or "bias":
```
optimizer_config: {
"name": "decoupled_lionw",
"lr": 1e-3,
"weight_decay": 1e-2,
"betas": [0.9, 0.999],
"eps": 1e-8,
"disable_grad": ["norm", "bias"]
}
```
To modify the optimizer args separately for all parameters whose names contain
the string "norm" and for all whose names contain "bias":
```
optimizer_config: {
"name": "decoupled_lionw",
"lr": 1e-3,
"weight_decay": 1e-2,
"betas": [0.9, 0.999],
"eps": 1e-8,
"param_groups": [
{
"param_str_match": "norm",
"lr": 1e-4,
"weight_decay": 0.0,
},
{
"param_str_match": "bias",
"lr": 5e-4,
"weight_decay": 0.0,
},
],
}
```
Args:
    model (torch.nn.Module): Model to extract parameters from.
    optimizer_config (Dict[str, Any]): Optimizer config.

Returns:
    Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]: An iterable of
    torch.Tensors or dicts. Specifies which tensors should be optimized.
"""
if 'disable_grad' in optimizer_config.keys():
str_match = optimizer_config.pop('disable_grad')
if isinstance(str_match, str):
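The diff above is truncated, so the full function body is not shown. For orientation, here is a minimal sketch of how a function with this docstring could apply the `disable_grad` and `param_groups` keys; the regex-based name matching, the ordering of the groups, and the helper name `_extract_param_groups_sketch` are assumptions for illustration, not the actual llmfoundry implementation.

```python
import re
from typing import Any, Dict, Iterable, Union

import torch


def _extract_param_groups_sketch(
    model: torch.nn.Module,
    optimizer_config: Dict[str, Any],
) -> Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]:
    # Freeze every parameter whose name matches one of the `disable_grad` strings.
    if 'disable_grad' in optimizer_config:
        str_matches = optimizer_config.pop('disable_grad')
        if isinstance(str_matches, str):
            str_matches = [str_matches]
        for str_match in str_matches:
            for name, param in model.named_parameters():
                if re.search(str_match, name):
                    param.requires_grad = False

    # Build one optimizer param group per `param_groups` entry; each entry keeps
    # its extra keys (lr, weight_decay, ...) as per-group optimizer overrides.
    if 'param_groups' in optimizer_config:
        params = []
        param_dict = dict(model.named_parameters())
        for group_config in optimizer_config.pop('param_groups'):
            group_config = dict(group_config)  # avoid mutating the caller's config
            str_match = group_config.pop('param_str_match')
            matched_names = [n for n in param_dict if re.search(str_match, n)]
            group = {'params': [param_dict.pop(n) for n in matched_names]}
            group.update(group_config)
            params.append(group)

        # Any parameter not claimed by a group falls back to the default group,
        # which uses the top-level optimizer args.
        params.insert(0, {'params': list(param_dict.values())})
        return params

    # No param_groups: optimize all model parameters with the top-level args.
    return model.parameters()
```

The returned list can be passed directly to a torch optimizer constructor, e.g. `torch.optim.AdamW(param_groups, lr=1e-3)`, which applies each group's overrides on top of the top-level optimizer args.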
