
Commit

painn to bb and heads
misko committed Jul 24, 2024
1 parent 42f1a11 commit 91a538a
Showing 1 changed file with 69 additions and 36 deletions.
src/fairchem/core/models/painn/painn.py (105 changes: 69 additions, 36 deletions)
@@ -50,8 +50,8 @@
from .utils import get_edge_id, repeat_blocks


-@registry.register_model("painn")
-class PaiNN(BaseModel):
+@registry.register_model("painn_backbone")
+class PaiNNBB(BaseModel):
    r"""PaiNN model based on the description in Schütt et al. (2021):
    Equivariant message passing for the prediction of tensorial properties
    and molecular spectra, https://arxiv.org/abs/2102.03150.
@@ -116,19 +116,8 @@ def __init__(
            self.update_layers.append(PaiNNUpdate(hidden_channels))
            setattr(self, "upd_out_scalar_scale_%d" % i, ScaleFactor())

-        self.out_energy = nn.Sequential(
-            nn.Linear(hidden_channels, hidden_channels // 2),
-            ScaledSiLU(),
-            nn.Linear(hidden_channels // 2, 1),
-        )
-
-        if self.regress_forces is True and self.direct_forces is True:
-            self.out_forces = PaiNNOutput(hidden_channels)
-
        self.inv_sqrt_2 = 1 / math.sqrt(2.0)

        self.reset_parameters()

        load_scales_compat(self, scale_file)

    def reset_parameters(self) -> None:
@@ -361,7 +350,6 @@ def generate_graph_values(self, data):
    @conditional_grad(torch.enable_grad())
    def forward(self, data):
        pos = data.pos
-        batch = data.batch
        z = data.atomic_numbers.long()

        if self.regress_forces and not self.direct_forces:
@@ -398,28 +386,7 @@ def forward(self, data):
            vec = vec + dvec
            x = getattr(self, "upd_out_scalar_scale_%d" % i)(x)

-        #### Output block #####################################################
-
-        per_atom_energy = self.out_energy(x).squeeze(1)
-        energy = scatter(per_atom_energy, batch, dim=0)
-        outputs = {"energy": energy}
-
-        if self.regress_forces:
-            if self.direct_forces:
-                forces = self.out_forces(x, vec)
-            else:
-                forces = (
-                    -1
-                    * torch.autograd.grad(
-                        x,
-                        pos,
-                        grad_outputs=torch.ones_like(x),
-                        create_graph=True,
-                    )[0]
-                )
-            outputs["forces"] = forces
-
-        return outputs
+        return {"x": x, "vec": vec}

    @property
    def num_params(self) -> int:
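With the old output block deleted, the backbone's forward now ends by returning its raw per-atom embeddings, and the new heads below are written against that dictionary. A minimal sketch of the contract, where the sizes are illustrative assumptions and the vector feature shape follows the PaiNN message/update layers:

import torch

num_atoms, hidden_channels = 8, 512  # illustrative sizes, not fixed by the model
emb = {
    "x": torch.zeros(num_atoms, hidden_channels),       # per-atom scalar features
    "vec": torch.zeros(num_atoms, 3, hidden_channels),  # per-atom vector features
}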
@@ -625,3 +592,69 @@ def forward(self, x, v):

        x = self.act(x)
        return x, v


@registry.register_model("painn_energy_head")
class PaiNN_energy_head(nn.Module):
def __init__(self, backbone, backbone_config, head_config):
super().__init__()

self.out_energy = nn.Sequential(
nn.Linear(backbone.hidden_channels, backbone.hidden_channels // 2),
ScaledSiLU(),
nn.Linear(backbone.hidden_channels // 2, 1),
)

nn.init.xavier_uniform_(self.out_energy[0].weight)
self.out_energy[0].bias.data.fill_(0)
nn.init.xavier_uniform_(self.out_energy[2].weight)
self.out_energy[2].bias.data.fill_(0)

def forward(self, x, emb):
per_atom_energy = self.out_energy(emb["x"]).squeeze(1)
return scatter(per_atom_energy, x.batch, dim=0)
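The head reduces per-atom energies to per-system energies with a segment sum over data.batch. A self-contained toy of that reduction, assuming torch_scatter's scatter (whose default reduce is a sum) behaves like the scatter this module imports:

import torch
from torch_scatter import scatter

per_atom_energy = torch.tensor([0.1, 0.2, 0.3, 0.4])  # four atoms, two systems
batch = torch.tensor([0, 0, 1, 1])                     # atom -> system assignment
energy = scatter(per_atom_energy, batch, dim=0)        # sums within each system
print(energy)  # tensor([0.3000, 0.7000])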


@registry.register_model("painn_force_head")
class PaiNN_force_head(nn.Module):
def __init__(self, backbone, backbone_config, head_config):
super().__init__()
self.direct_forces = backbone.direct_forces

if self.direct_forces:
self.out_forces = PaiNNOutput(backbone.hidden_channels)

def forward(self, x, emb):
if self.direct_forces:
forces = self.out_forces(emb["x"], emb["vec"])
else:
forces = (
-1
* torch.autograd.grad(
emb["x"],
x.pos,
grad_outputs=torch.ones_like(emb["x"]),
create_graph=True,
)[0]
)
return forces
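When direct_forces is off, the head differentiates the backbone's scalar output with respect to positions, which relies on positions being marked to require gradients earlier in the backbone's forward (presumably in the collapsed branch under "if self.regress_forces and not self.direct_forces:"). A toy stand-in for the same autograd pattern; the quadratic energy here is an assumption for illustration, not the model:

import torch

pos = torch.randn(5, 3, requires_grad=True)  # five atoms in 3D
energy = (pos ** 2).sum()                    # toy scalar standing in for emb["x"]
forces = -torch.autograd.grad(energy, pos, create_graph=True)[0]
print(forces.shape)  # torch.Size([5, 3])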


@registry.register_model("painn")
class PaiNN(PaiNNBB):

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

self.energy_head = PaiNN_energy_head(self, {}, {})
self.force_head = PaiNN_force_head(self, {}, {})

@conditional_grad(torch.enable_grad())
def forward(self, data):
bb_outputs = super().forward(data)

outputs = {"energy": self.energy_head(data, bb_outputs)}
if self.regress_forces:
outputs["forces"] = self.force_head(data, bb_outputs)

return outputs
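Taken together, the registered painn model keeps its old interface: one backbone pass feeds both heads. A hedged sketch of a call site, where the constructor arguments and the data batch are assumptions based on the surrounding code, not part of this commit:

# Hypothetical call site; data is a torch_geometric-style Batch carrying
# pos, batch, and atomic_numbers, and kwargs mirror the usual PaiNN config.
model = PaiNN(hidden_channels=512, num_layers=6)  # assumed arguments
out = model(data)
energy = out["energy"]      # one scalar per system
forces = out.get("forces")  # [num_atoms, 3] when regress_forces is enabled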
