From 91a538a6bdc5b73ebf589665cb4cd6a16bb09582 Mon Sep 17 00:00:00 2001 From: Misko Date: Wed, 24 Jul 2024 21:59:17 +0000 Subject: [PATCH] painn to backbone and heads --- src/fairchem/core/models/painn/painn.py | 105 ++++++++++++++++-------- 1 file changed, 69 insertions(+), 36 deletions(-) diff --git a/src/fairchem/core/models/painn/painn.py b/src/fairchem/core/models/painn/painn.py index 8843f02b2..d86a04c39 100644 --- a/src/fairchem/core/models/painn/painn.py +++ b/src/fairchem/core/models/painn/painn.py @@ -50,8 +50,8 @@ from .utils import get_edge_id, repeat_blocks -@registry.register_model("painn") -class PaiNN(BaseModel): +@registry.register_model("painn_backbone") +class PaiNNBB(BaseModel): r"""PaiNN model based on the description in Schütt et al. (2021): Equivariant message passing for the prediction of tensorial properties and molecular spectra, https://arxiv.org/abs/2102.03150. @@ -116,19 +116,8 @@ def __init__( self.update_layers.append(PaiNNUpdate(hidden_channels)) setattr(self, "upd_out_scalar_scale_%d" % i, ScaleFactor()) - self.out_energy = nn.Sequential( - nn.Linear(hidden_channels, hidden_channels // 2), - ScaledSiLU(), - nn.Linear(hidden_channels // 2, 1), - ) - - if self.regress_forces is True and self.direct_forces is True: - self.out_forces = PaiNNOutput(hidden_channels) - self.inv_sqrt_2 = 1 / math.sqrt(2.0) - self.reset_parameters() - load_scales_compat(self, scale_file) def reset_parameters(self) -> None: @@ -361,7 +350,6 @@ def generate_graph_values(self, data): @conditional_grad(torch.enable_grad()) def forward(self, data): pos = data.pos - batch = data.batch z = data.atomic_numbers.long() if self.regress_forces and not self.direct_forces: @@ -398,28 +386,7 @@ def forward(self, data): vec = vec + dvec x = getattr(self, "upd_out_scalar_scale_%d" % i)(x) - #### Output block ##################################################### - - per_atom_energy = self.out_energy(x).squeeze(1) - energy = scatter(per_atom_energy, batch, dim=0) - outputs = 
{"energy": energy} - - if self.regress_forces: - if self.direct_forces: - forces = self.out_forces(x, vec) - else: - forces = ( - -1 - * torch.autograd.grad( - x, - pos, - grad_outputs=torch.ones_like(x), - create_graph=True, - )[0] - ) - outputs["forces"] = forces - - return outputs + return {"x": x, "vec": vec} @property def num_params(self) -> int: @@ -625,3 +592,69 @@ def forward(self, x, v): x = self.act(x) return x, v + + +@registry.register_model("painn_energy_head") +class PaiNN_energy_head(nn.Module): + def __init__(self, backbone, backbone_config, head_config): + super().__init__() + + self.out_energy = nn.Sequential( + nn.Linear(backbone.hidden_channels, backbone.hidden_channels // 2), + ScaledSiLU(), + nn.Linear(backbone.hidden_channels // 2, 1), + ) + + nn.init.xavier_uniform_(self.out_energy[0].weight) + self.out_energy[0].bias.data.fill_(0) + nn.init.xavier_uniform_(self.out_energy[2].weight) + self.out_energy[2].bias.data.fill_(0) + + def forward(self, x, emb): + per_atom_energy = self.out_energy(emb["x"]).squeeze(1) + return scatter(per_atom_energy, x.batch, dim=0) + + +@registry.register_model("painn_force_head") +class PaiNN_force_head(nn.Module): + def __init__(self, backbone, backbone_config, head_config): + super().__init__() + self.direct_forces = backbone.direct_forces + + if self.direct_forces: + self.out_forces = PaiNNOutput(backbone.hidden_channels) + + def forward(self, x, emb): + if self.direct_forces: + forces = self.out_forces(emb["x"], emb["vec"]) + else: + forces = ( + -1 + * torch.autograd.grad( + emb["x"], + x.pos, + grad_outputs=torch.ones_like(emb["x"]), + create_graph=True, + )[0] + ) + return forces + + +@registry.register_model("painn") +class PaiNN(PaiNNBB): + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.energy_head = PaiNN_energy_head(self, {}, {}) + self.force_head = PaiNN_force_head(self, {}, {}) + + @conditional_grad(torch.enable_grad()) + def forward(self, data): + 
bb_outputs = super().forward(data) + + outputs = {"energy": self.energy_head(data, bb_outputs)} + if self.regress_forces: + outputs["forces"] = self.force_head(data, bb_outputs) + + return outputs