diff --git a/matgl/apps/pes.py b/matgl/apps/pes.py index 85bbe683..b4c168e8 100644 --- a/matgl/apps/pes.py +++ b/matgl/apps/pes.py @@ -56,26 +56,13 @@ def __init__( self.data_mean = data_mean if data_mean is not None else torch.zeros(1) self.data_std = data_std if data_std is not None else torch.ones(1) - def forward( - self, g: dgl.DGLGraph, state_attr: torch.Tensor | None = None, l_g: dgl.DGLGraph | None = None - ) -> tuple: - """Args: - g: DGL graph - state_attr: State attrs - l_g: Line graph. - - Returns: - energies, forces, stresses, hessian: torch.Tensor - """ + def _calc_forces_stresses_hessian( + self, g: dgl.DGLGraph, total_energies: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Calculate optional forces, stresses and hessian.""" forces = torch.zeros(1) stresses = torch.zeros(1) hessian = torch.zeros(1) - if self.calc_forces: - g.ndata["pos"].requires_grad_(True) - total_energies = self.data_std * self.model(g=g, state_attr=state_attr, l_g=l_g) + self.data_mean - if self.element_refs is not None: - property_offset = torch.squeeze(self.element_refs(g)) - total_energies += property_offset if self.calc_forces: grads = grad( @@ -118,4 +105,38 @@ def forward( num_nodes = g.batch_num_nodes()[graph_id] count_node = count_node + num_nodes stresses = torch.cat(sts) + + return forces, stresses, hessian + + def forward( + self, g: dgl.DGLGraph, state_attr: torch.Tensor | None = None, l_g: dgl.DGLGraph | None = None + ) -> tuple[torch.Tensor, ...]: + """Args: + g: DGL graph + state_attr: State attrs + l_g: Line graph. + + Returns: + energies, forces, stresses, hessian: torch.Tensor + """ + if self.calc_forces: + g.ndata["pos"].requires_grad_(True) + + predictions = self.model(g, state_attr, l_g) + if isinstance(predictions, tuple) and len(predictions) > 1: + total_energies, site_wise = predictions + else: + total_energies = predictions + site_wise = None + + total_energies = self.data_std * total_energies + self.data_mean + if self.element_refs is not None: + property_offset = torch.squeeze(self.element_refs(g)) + total_energies += property_offset + + forces, stresses, hessian = self._calc_forces_stresses_hessian(g, total_energies) + + if site_wise is not None: + return total_energies, forces, stresses, hessian, site_wise + return total_energies, forces, stresses, hessian