diff --git a/matgl/layers/__init__.py b/matgl/layers/__init__.py
index fce5f89d..aee1fdc8 100644
--- a/matgl/layers/__init__.py
+++ b/matgl/layers/__init__.py
@@ -1,7 +1,7 @@
 """This package implements the layers for M*GNet."""
 from __future__ import annotations
 
-from matgl.layers._activations import SoftExponential, SoftPlus2
+from matgl.layers._activations import ActivationFunction
 from matgl.layers._atom_ref import AtomRef
 from matgl.layers._basis import FourierExpansion, RadialBesselFunction, SphericalBesselWithHarmonics
 from matgl.layers._bond import BondExpansion
diff --git a/matgl/layers/_activations.py b/matgl/layers/_activations.py
index 5f238365..4c6b4124 100644
--- a/matgl/layers/_activations.py
+++ b/matgl/layers/_activations.py
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import math
+from enum import Enum
 
 import torch
 from torch import nn
@@ -70,3 +71,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.alpha < 0.0:
             return -torch.log(1.0 - self.alpha * (x + self.alpha)) / self.alpha
         return (torch.exp(self.alpha * x) - 1.0) / self.alpha + self.alpha
+
+
+class ActivationFunction(Enum):
+    """Enumeration of optional activation functions."""
+
+    swish = nn.SiLU
+    sigmoid = nn.Sigmoid
+    tanh = nn.Tanh
+    softplus = nn.Softplus
+    softplus2 = SoftPlus2
+    softexp = SoftExponential
diff --git a/matgl/models/_m3gnet.py b/matgl/models/_m3gnet.py
index 4cb80a98..210e1436 100644
--- a/matgl/models/_m3gnet.py
+++ b/matgl/models/_m3gnet.py
@@ -25,14 +25,13 @@
 )
 from matgl.layers import (
     MLP,
+    ActivationFunction,
     BondExpansion,
     EmbeddingBlock,
     GatedMLP,
     M3GNetBlock,
     ReduceReadOut,
     Set2SetReadOut,
-    SoftExponential,
-    SoftPlus2,
     SphericalBesselWithHarmonics,
     ThreeBodyInteractions,
     WeightedReadOut,
@@ -113,18 +112,12 @@ def __init__(
 
         self.save_args(locals(), kwargs)
 
-        if activation_type == "swish":
-            activation = nn.SiLU()  # type: ignore
-        elif activation_type == "tanh":
-            activation = nn.Tanh()  # type: ignore
-        elif activation_type == "sigmoid":
-            activation = nn.Sigmoid()  # type: ignore
-        elif activation_type == "softplus2":
-            activation = SoftPlus2()  # type: ignore
-        elif activation_type == "softexp":
-            activation = SoftExponential()  # type: ignore
-        else:
-            raise ValueError("Invalid activation type, please try using swish, sigmoid, tanh, softplus2, softexp")
+        try:
+            activation: nn.Module = ActivationFunction[activation_type].value()
+        except KeyError:
+            raise ValueError(
+                f"Invalid activation type, please try using one of {[af.name for af in ActivationFunction]}"
+            ) from None
 
         if element_types is None:
             self.element_types = DEFAULT_ELEMENT_TYPES
diff --git a/matgl/models/_megnet.py b/matgl/models/_megnet.py
index 02ac9755..0ccf3a87 100644
--- a/matgl/models/_megnet.py
+++ b/matgl/models/_megnet.py
@@ -17,7 +17,7 @@
 
 from matgl.config import DEFAULT_ELEMENT_TYPES
 from matgl.graph.compute import compute_pair_vector_and_distance
-from matgl.layers import MLP, BondExpansion, EdgeSet2Set, EmbeddingBlock, MEGNetBlock, SoftExponential, SoftPlus2
+from matgl.layers import MLP, ActivationFunction, BondExpansion, EdgeSet2Set, EmbeddingBlock, MEGNetBlock
 from matgl.utils.io import IOMixIn
 
 if TYPE_CHECKING:
@@ -96,19 +96,12 @@ def __init__(
         edge_dims = [dim_edge_embedding, *hidden_layer_sizes_input]
         state_dims = [dim_state_embedding, *hidden_layer_sizes_input]
 
-        activation: nn.Module
-        if activation_type == "swish":
-            activation = nn.SiLU()
-        elif activation_type == "sigmoid":
-            activation = nn.Sigmoid()
-        elif activation_type == "tanh":
-            activation = nn.Tanh()
-        elif activation_type == "softplus2":
-            activation = SoftPlus2()
-        elif activation_type == "softexp":
-            activation = SoftExponential()
-        else:
-            raise ValueError("Invalid activation type, please try using swish, sigmoid, tanh, softplus2, softexp")
+        try:
+            activation: nn.Module = ActivationFunction[activation_type].value()
+        except KeyError:
+            raise ValueError(
+                f"Invalid activation type, please try using one of {[af.name for af in ActivationFunction]}"
+            ) from None
 
         self.embedding = EmbeddingBlock(
             degree_rbf=dim_edge_embedding,
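
A minimal usage sketch (not part of the diff) of how the new ActivationFunction enum resolves a config string to an instantiated torch.nn module, and what the KeyError-to-ValueError handling in the models surfaces; it assumes the matgl package with the changes above is importable:

    from torch import nn

    from matgl.layers import ActivationFunction

    # Lookup by name returns the enum member; .value is the activation class,
    # and calling it instantiates the module (here nn.SiLU).
    activation: nn.Module = ActivationFunction["swish"].value()
    assert isinstance(activation, nn.SiLU)

    # Unknown names raise KeyError, which M3GNet/MEGNet re-raise as a
    # ValueError listing the valid options.
    try:
        ActivationFunction["relu"]
    except KeyError:
        print([af.name for af in ActivationFunction])
        # ['swish', 'sigmoid', 'tanh', 'softplus', 'softplus2', 'softexp']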