From 8aad57f7435ec5a7adc9794420a24231c3c31a66 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 14 Dec 2023 21:20:09 +0000 Subject: [PATCH] enable config to have function --- llmfoundry/models/layers/ffn.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index cdd6e376d7..5a86b9424c 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -37,10 +37,21 @@ def resolve_ffn_act_fn( Callable[[torch.Tensor], torch.Tensor]: The activation function. """ config = deepcopy(config or _FFN_ACT_FN_DEFAULT) - name = config.pop('name') - if not hasattr(torch.nn.functional, name): - raise ValueError(f'Unrecognised activation function name ({name}).') - act = getattr(torch.nn.functional, name) + + if 'function' in config.keys(): + act = config.pop('function') + if not isinstance(act, Callable): + raise ValueError(f'act `function` ({act}) must be Callable.') + elif 'name' in config.keys(): + name = config.pop('name') + if not hasattr(torch.nn.functional, name): + raise ValueError(f'Unrecognised activation function name ({name}).') + act = getattr(torch.nn.functional, name) + else: + raise ValueError( + f'FFN activation function config must specify either `function` or function `name`.' + ) + return partial(act, **config)