Skip to content

Commit

Permalink
enable config to have function
Browse files Browse the repository at this point in the history
  • Loading branch information
vchiley committed Dec 14, 2023
1 parent 87305df commit 8aad57f
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions llmfoundry/models/layers/ffn.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,21 @@ def resolve_ffn_act_fn(
Callable[[torch.Tensor], torch.Tensor]: The activation function.
"""
config = deepcopy(config or _FFN_ACT_FN_DEFAULT)
name = config.pop('name')
if not hasattr(torch.nn.functional, name):
raise ValueError(f'Unrecognised activation function name ({name}).')
act = getattr(torch.nn.functional, name)

if 'function' in config.keys():
act = config.pop('function')
if not isinstance(act, Callable):
raise ValueError(f'act `function` ({act}) must be Callable.')
elif 'name' in config.keys():
name = config.pop('name')
if not hasattr(torch.nn.functional, name):
raise ValueError(f'Unrecognised activation function name ({name}).')
act = getattr(torch.nn.functional, name)
else:
raise ValueError(
f'FFN activation function config must specify either `function` or function `name`.'
)

return partial(act, **config)


Expand Down

0 comments on commit 8aad57f

Please sign in to comment.