ENH: use enum for available activation functions (#128)
lbluque authored Aug 8, 2023
1 parent 5949ed8 commit e6ce8b1
Showing 4 changed files with 27 additions and 29 deletions.
matgl/layers/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -1,7 +1,7 @@
"""This package implements the layers for M*GNet."""
from __future__ import annotations

from matgl.layers._activations import SoftExponential, SoftPlus2
from matgl.layers._activations import ActivationFunction
from matgl.layers._atom_ref import AtomRef
from matgl.layers._basis import FourierExpansion, RadialBesselFunction, SphericalBesselWithHarmonics
from matgl.layers._bond import BondExpansion
matgl/layers/_activations.py (12 changes: 12 additions & 0 deletions)
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import math
+from enum import Enum
 
 import torch
 from torch import nn
@@ -70,3 +71,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.alpha < 0.0:
             return -torch.log(1.0 - self.alpha * (x + self.alpha)) / self.alpha
         return (torch.exp(self.alpha * x) - 1.0) / self.alpha + self.alpha
+
+
+class ActivationFunction(Enum):
+    """Enumeration of optional activation functions."""
+
+    swish = nn.SiLU
+    sigmoid = nn.Sigmoid
+    tanh = nn.Tanh
+    softplus = nn.Softplus
+    softplus2 = SoftPlus2
+    softexp = SoftExponential
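
For reference, a minimal usage sketch (not part of this diff) of how the new enum is intended to be used: each member maps an activation name to a module class, so indexing by name and calling `.value()` instantiates the layer, mirroring the try/except pattern adopted in the model constructors below. The `activation_type` string here is just an example value.

    from torch import nn

    from matgl.layers import ActivationFunction

    activation_type = "softplus2"  # example user-supplied name
    try:
        # Indexing the enum by member name returns the member; .value is the
        # stored class (e.g. SoftPlus2), which is called to build the module.
        activation: nn.Module = ActivationFunction[activation_type].value()
    except KeyError:
        raise ValueError(
            f"Invalid activation type, please try using one of {[af.name for af in ActivationFunction]}"
        ) from None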
matgl/models/_m3gnet.py (21 changes: 7 additions & 14 deletions)
@@ -25,14 +25,13 @@
 )
 from matgl.layers import (
     MLP,
+    ActivationFunction,
     BondExpansion,
     EmbeddingBlock,
     GatedMLP,
     M3GNetBlock,
     ReduceReadOut,
     Set2SetReadOut,
-    SoftExponential,
-    SoftPlus2,
     SphericalBesselWithHarmonics,
     ThreeBodyInteractions,
     WeightedReadOut,
@@ -113,18 +112,12 @@ def __init__(

         self.save_args(locals(), kwargs)
 
-        if activation_type == "swish":
-            activation = nn.SiLU()  # type: ignore
-        elif activation_type == "tanh":
-            activation = nn.Tanh()  # type: ignore
-        elif activation_type == "sigmoid":
-            activation = nn.Sigmoid()  # type: ignore
-        elif activation_type == "softplus2":
-            activation = SoftPlus2()  # type: ignore
-        elif activation_type == "softexp":
-            activation = SoftExponential()  # type: ignore
-        else:
-            raise ValueError("Invalid activation type, please try using swish, sigmoid, tanh, softplus2, softexp")
+        try:
+            activation: nn.Module = ActivationFunction[activation_type].value()
+        except KeyError:
+            raise ValueError(
+                f"Invalid activation type, please try using one of {[af.name for af in ActivationFunction]}"
+            ) from None
 
         if element_types is None:
             self.element_types = DEFAULT_ELEMENT_TYPES
matgl/models/_megnet.py (21 changes: 7 additions & 14 deletions)
@@ -17,7 +17,7 @@

 from matgl.config import DEFAULT_ELEMENT_TYPES
 from matgl.graph.compute import compute_pair_vector_and_distance
-from matgl.layers import MLP, BondExpansion, EdgeSet2Set, EmbeddingBlock, MEGNetBlock, SoftExponential, SoftPlus2
+from matgl.layers import MLP, ActivationFunction, BondExpansion, EdgeSet2Set, EmbeddingBlock, MEGNetBlock
 from matgl.utils.io import IOMixIn
 
 if TYPE_CHECKING:
@@ -96,19 +96,12 @@ def __init__(
         edge_dims = [dim_edge_embedding, *hidden_layer_sizes_input]
         state_dims = [dim_state_embedding, *hidden_layer_sizes_input]
 
-        activation: nn.Module
-        if activation_type == "swish":
-            activation = nn.SiLU()
-        elif activation_type == "sigmoid":
-            activation = nn.Sigmoid()
-        elif activation_type == "tanh":
-            activation = nn.Tanh()
-        elif activation_type == "softplus2":
-            activation = SoftPlus2()
-        elif activation_type == "softexp":
-            activation = SoftExponential()
-        else:
-            raise ValueError("Invalid activation type, please try using swish, sigmoid, tanh, softplus2, softexp")
+        try:
+            activation: nn.Module = ActivationFunction[activation_type].value()
+        except KeyError:
+            raise ValueError(
+                f"Invalid activation type, please try using one of {[af.name for af in ActivationFunction]}"
+            ) from None
 
         self.embedding = EmbeddingBlock(
             degree_rbf=dim_edge_embedding,
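
A hypothetical follow-on sketch (not part of this commit) illustrating the design benefit: because both models now resolve activations through the enum, supporting a new activation only requires registering one more member, for example an `elu` entry; that name would then be a valid `activation_type` everywhere the lookup is used. The standalone enum below is a simplified stand-in for the real one.

    from enum import Enum

    from torch import nn

    class ActivationFunction(Enum):
        """Hypothetical extension of the enum with one extra member."""

        swish = nn.SiLU
        sigmoid = nn.Sigmoid
        tanh = nn.Tanh
        softplus = nn.Softplus
        elu = nn.ELU  # hypothetical new entry, not in the actual commit

    # Both model constructors would then accept activation_type="elu".
    print([af.name for af in ActivationFunction])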
