Skip to content

Commit

Permalink
Add potential with additional site wise predictions (#125)
Browse files Browse the repository at this point in the history
* ENH: add potential with additional site wise predictions

* STY: linting

* ENH: single Potential class
  • Loading branch information
lbluque authored Aug 8, 2023
1 parent 90c62d1 commit fc5cdf7
Showing 1 changed file with 38 additions and 17 deletions.
55 changes: 38 additions & 17 deletions matgl/apps/pes.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,26 +56,13 @@ def __init__(
self.data_mean = data_mean if data_mean is not None else torch.zeros(1)
self.data_std = data_std if data_std is not None else torch.ones(1)

def forward(
self, g: dgl.DGLGraph, state_attr: torch.Tensor | None = None, l_g: dgl.DGLGraph | None = None
) -> tuple:
"""Args:
g: DGL graph
state_attr: State attrs
l_g: Line graph.
Returns:
energies, forces, stresses, hessian: torch.Tensor
"""
def _calc_forces_stresses_hessian(
self, g: dgl.DGLGraph, total_energies: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Calculate optional forces, stresses and hessian."""
forces = torch.zeros(1)
stresses = torch.zeros(1)
hessian = torch.zeros(1)
if self.calc_forces:
g.ndata["pos"].requires_grad_(True)
total_energies = self.data_std * self.model(g=g, state_attr=state_attr, l_g=l_g) + self.data_mean
if self.element_refs is not None:
property_offset = torch.squeeze(self.element_refs(g))
total_energies += property_offset

if self.calc_forces:
grads = grad(
Expand Down Expand Up @@ -118,4 +105,38 @@ def forward(
num_nodes = g.batch_num_nodes()[graph_id]
count_node = count_node + num_nodes
stresses = torch.cat(sts)

return forces, stresses, hessian

def forward(
self, g: dgl.DGLGraph, state_attr: torch.Tensor | None = None, l_g: dgl.DGLGraph | None = None
) -> tuple[torch.Tensor, ...]:
"""Args:
g: DGL graph
state_attr: State attrs
l_g: Line graph.
Returns:
energies, forces, stresses, hessian: torch.Tensor
"""
if self.calc_forces:
g.ndata["pos"].requires_grad_(True)

predictions = self.model(g, state_attr, l_g)
if isinstance(predictions, tuple) and len(predictions) > 1:
total_energies, site_wise = predictions
else:
total_energies = predictions
site_wise = None

total_energies = self.data_std * total_energies + self.data_mean
if self.element_refs is not None:
property_offset = torch.squeeze(self.element_refs(g))
total_energies += property_offset

forces, stresses, hessian = self._calc_forces_stresses_hessian(g, total_energies)

if site_wise is not None:
return total_energies, forces, stresses, hessian, site_wise

return total_energies, forces, stresses, hessian

0 comments on commit fc5cdf7

Please sign in to comment.