Added Exponent Normal Smearing for radial basis functions (#366)
* improve TensorNet model coverage

* Update pyproject.toml

Signed-off-by: Tsz Wai Ko <[email protected]>

* Improve the unit test for SO(3) equivariance in TensorNet class

* improve SO3Net model class coverage and simplify TensorNet implementations

* improve the coverage in MLP_norm class

* Improve the implementation of three-body interactions

* fixed black

* Optimize the speed of _compute_3body class

* type checking is added for scheduler

* update M3GNet Potential training notebook to demonstrate obtaining and using element offsets

* Downgrade sympy to avoid crashes in SO3 operations

* Smooth L1 loss function is added and unit tests are improved

* merge the methods predict_structure and featurize_structure into a single function covering both

* remove unnecessary else statement for training magmoms

* modify so3 operation implementation to make unit tests pass after the sympy update

* skip test_load_all_models for MacOS pytest now

* Reference for CHGNet is added

* Update README.md and index.md for including CHGNet

Signed-off-by: Tsz Wai Ko <[email protected]>

* add more description of using CHGNet pretrained models in "Relaxations and Simulations using the M3GNet Universal Potential.ipynb"

* A command-line interface for performing ASE MD simulations is added

* added back py.typed

* ExpNormal Smearing for radial basis functions is added

* Changed deprecated torch.scalar_tensor to torch.Tensor

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: Tsz Wai Ko <[email protected]>

* Converted the float number into a tensor

Signed-off-by: Tsz Wai Ko <[email protected]>

---------

Signed-off-by: Tsz Wai Ko <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
kenko911 and coderabbitai[bot] committed Sep 23, 2024
1 parent 3a4f26d commit 3bc8467
Showing 4 changed files with 80 additions and 3 deletions.
56 changes: 56 additions & 0 deletions src/matgl/layers/_basis.py
@@ -9,6 +9,7 @@

import matgl
from matgl.layers._three_body import combine_sbf_shf
from matgl.utils.cutoff import cosine_cutoff
from matgl.utils.maths import SPHERICAL_BESSEL_ROOTS, _get_lambda_func


@@ -360,3 +361,58 @@ def forward(self, line_graph):
sbf = self.sbf(line_graph.edata["triple_bond_lengths"])
shf = self.shf(line_graph.edata["cos_theta"], line_graph.edata["phi"])
return combine_sbf_shf(sbf, shf, max_n=self.max_n, max_l=self.max_l, use_phi=self.use_phi)


class ExpNormalFunction(nn.Module):
"""Implementation of radial basis function using exponential normal smearing."""

def __init__(self, cutoff: float = 5.0, num_rbf: int = 50, learnable: bool = True):
"""
Initialize ExpNormalFunction.
Args:
cutoff (float): The cutoff distance beyond which interactions are considered negligible. Default is 5.0.
num_rbf (int): The number of radial basis functions (RBF) to use. Default is 50.
learnable (bool): If True, the means and betas parameters are learnable.
If False, they are fixed. Default is True.
"""
super().__init__()
self.cutoff = cutoff
self.num_rbf = num_rbf
self.learnable = learnable

self.alpha = 5.0 / cutoff

means, betas = self._initial_params()
if learnable:
self.register_parameter("means", nn.Parameter(means))
self.register_parameter("betas", nn.Parameter(betas))
else:
self.register_buffer("means", means)
self.register_buffer("betas", betas)

def _initial_params(self):
"""Initialize the means and betas parameters."""
start_value = torch.exp(torch.tensor(-self.cutoff, dtype=matgl.float_th))
means = torch.linspace(start_value, 1, self.num_rbf)
betas = torch.tensor([(2 / self.num_rbf * (1 - start_value)) ** -2] * self.num_rbf)
return means, betas

def reset_parameters(self):
"""Reset the means and betas to their initial values."""
means, betas = self._initial_params()
self.means.data.copy_(means)
self.betas.data.copy_(betas)

def forward(self, r: torch.Tensor):
"""
Compute the radial basis function for the input distances.
Args:
r (torch.Tensor): Input distances.
Returns:
torch.Tensor: Smearing function applied to the input distances.
"""
r = r.unsqueeze(-1)
return cosine_cutoff(r, self.cutoff) * torch.exp(-self.betas * (torch.exp(self.alpha * (-r)) - self.means) ** 2)
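
For reference, a minimal usage sketch of the new basis function (not part of the commit; the distance values below are illustrative, while the class and argument names come from the diff above):

import torch
from matgl.layers._basis import ExpNormalFunction

# illustrative pairwise bond distances in angstrom, e.g. taken from g.edata["bond_dist"]
r = torch.tensor([1.2, 2.5, 4.8])

# each basis value is cosine_cutoff(r) * exp(-beta_k * (exp(-alpha * r) - mu_k)**2), with alpha = 5 / cutoff
rbf = ExpNormalFunction(cutoff=5.0, num_rbf=16, learnable=False)
features = rbf(r)  # shape (3, 16): one row of 16 basis values per distance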
8 changes: 5 additions & 3 deletions src/matgl/layers/_bond.py
@@ -7,7 +7,7 @@
import torch
from torch import nn

from matgl.layers._basis import GaussianExpansion, SphericalBesselFunction
from matgl.layers._basis import ExpNormalFunction, GaussianExpansion, SphericalBesselFunction


class BondExpansion(nn.Module):
@@ -18,7 +18,7 @@ def __init__(
max_l: int = 3,
max_n: int = 3,
cutoff: float = 5.0,
rbf_type: Literal["SphericalBessel", "Gaussian"] = "SphericalBessel",
rbf_type: Literal["SphericalBessel", "Gaussian", "ExpNorm"] = "SphericalBessel",
smooth: bool = False,
initial: float = 0.0,
final: float = 5.0,
@@ -30,7 +30,7 @@
max_l (int): order of angular part
max_n (int): order of radial part
cutoff (float): cutoff radius
rbf_type (str): type of radial basis function, i.e. either "SphericalBessel", "Gaussian" or "ExpNorm"
rbf_type (str): type of radial basis function .i.e. either "SphericalBessel", "ExpNorm" or 'Gaussian'
smooth (bool): whether apply the smooth version of spherical bessel functions or not
initial (float): initial point for gaussian expansion
final (float): final point for gaussian expansion
@@ -53,6 +53,8 @@ def __init__(
self.rbf = SphericalBesselFunction(max_l, max_n, cutoff, smooth) # type: ignore
elif rbf_type.lower() == "gaussian":
self.rbf = GaussianExpansion(initial, final, num_centers, width) # type: ignore
elif rbf_type.lower() == "expnorm":
self.rbf = ExpNormalFunction(cutoff, num_centers, learnable=True)
else:
raise ValueError("Undefined rbf_type, please use SphericalBessel, Gaussian or ExpNorm instead.")

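
A minimal sketch of selecting the new basis through BondExpansion (not part of the commit; the distances below are illustrative):

import torch
from matgl.layers._bond import BondExpansion

bond_dist = torch.tensor([1.1, 2.3, 3.7])  # illustrative bond distances in angstrom

# "ExpNorm" now routes to ExpNormalFunction; num_centers sets the number of radial basis functions
bond_expansion = BondExpansion(rbf_type="ExpNorm", cutoff=5.0, num_centers=9)
bond_basis = bond_expansion(bond_dist)  # expected shape (3, 9)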
12 changes: 12 additions & 0 deletions tests/layers/test_basis.py
@@ -8,6 +8,7 @@
create_line_graph,
)
from matgl.layers._basis import (
ExpNormalFunction,
FourierExpansion,
GaussianExpansion,
RadialBesselFunction,
@@ -57,6 +58,17 @@ def test_spherical_bessel_function():
assert [rbf.size(dim=0), rbf.size(dim=1)] == [11, 3]


def test_exp_normal_function():
r = torch.linspace(1.0, 5.0, 11)
rbf = ExpNormalFunction(cutoff=5.0, num_rbf=3, learnable=False)
res = rbf(r)
assert [res.size(dim=0), res.size(dim=1)] == [11, 3]

rbf = ExpNormalFunction(cutoff=5.0, num_rbf=3, learnable=True)
res = rbf(r)
assert [res.size(dim=0), res.size(dim=1)] == [11, 3]


def test_spherical_harmonic_function():
theta = torch.linspace(-1, 1, 10)
phi = torch.linspace(0, 2 * np.pi, 10)
7 changes: 7 additions & 0 deletions tests/layers/test_bond.py
@@ -35,6 +35,13 @@ def test_spherical_bessel(self, graph_MoS, graph_CO):
bond_basis = bond_expansion(g2.edata["bond_dist"])
assert bond_basis.shape == (2, 9)

def test_exp_normal(self, graph_MoS, graph_CO):
_, g1, _ = graph_MoS
_, g2, _ = graph_CO
bond_expansion = BondExpansion(rbf_type="ExpNorm", cutoff=4.0, num_centers=9)
bond_basis = bond_expansion(g1.edata["bond_dist"])
assert bond_basis.shape == (28, 9)

def test_exception(self):
with pytest.raises(ValueError, match="Undefined rbf_type"):
BondExpansion(rbf_type="nonsense")
