diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index cdd6e376d7..5a86b9424c 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -37,10 +37,21 @@ def resolve_ffn_act_fn( Callable[[torch.Tensor], torch.Tensor]: The activation function. """ config = deepcopy(config or _FFN_ACT_FN_DEFAULT) - name = config.pop('name') - if not hasattr(torch.nn.functional, name): - raise ValueError(f'Unrecognised activation function name ({name}).') - act = getattr(torch.nn.functional, name) + + if 'function' in config.keys(): + act = config.pop('function') + if not isinstance(act, Callable): + raise ValueError(f'act `function` ({act}) must be Callable.') + elif 'name' in config.keys(): + name = config.pop('name') + if not hasattr(torch.nn.functional, name): + raise ValueError(f'Unrecognised activation function name ({name}).') + act = getattr(torch.nn.functional, name) + else: + raise ValueError( + f'FFN activation function config must specify either `function` or function `name`.' + ) + return partial(act, **config)