From c05eaae8b4694094d17c579c9172bab572b51b12 Mon Sep 17 00:00:00 2001 From: Misko Date: Fri, 2 Aug 2024 14:50:24 -0700 Subject: [PATCH] Move select models to backbone + heads format and add support for hydra (#782) * convert escn to bb + heads * convert dimenet to bb + heads * gemnet_oc to backbone and heads * add additional parameter backbone config to heads * gemnet to bb and heads * painn to bb and heads * add eqv2 bb+heads; move to canonical naming * fix calculator loading by leaving original class in code * fix issues with calculator loading * lint fixes * move dimenet++ heads to one * add test for dimenet * add painn test * hydra and tests for gemnetH dppH painnH * add escnH and equiformerv2H * add gemnetdt gemnetdtH * add smoke test for schnet and scn * remove old examples * typo * fix gemnet with grad forces; add test for this * remove unused params; add backbone and head interface; add typing * remove unused second order output heads * remove OC20 suffix from equiformer * remove comment * rename and lint * fix dimenet test * fix tests * refactor generate graph * refactor generate graph * fix a messy cherry pick * final messy fix * graph data interface in eqv2 * refactor * no bbconfigs * no more headconfigs in inits * rename hydra * fix eqV2 * update test configs * final fixes * fix tutorial * rm comments * fix test --------- Co-authored-by: lbluque Co-authored-by: Luis Barroso-Luque (cherry picked from commit 08b8c1ea9f1858d7f8f14df1718f26997c1ca799) --- docs/legacy_tutorials/OCP_Tutorial.md | 2 +- src/fairchem/core/models/base.py | 137 +++++++- src/fairchem/core/models/dimenet_plus_plus.py | 147 +++++++-- .../core/models/equiformer_v2/__init__.py | 2 +- ...equiformer_v2_oc20.py => equiformer_v2.py} | 297 +++++++++++++++--- src/fairchem/core/models/escn/escn.py | 177 +++++++++-- src/fairchem/core/models/gemnet/gemnet.py | 195 +++++++++--- src/fairchem/core/models/gemnet_gp/gemnet.py | 64 ++-- .../core/models/gemnet_oc/gemnet_oc.py | 287 ++++++++++++++--- src/fairchem/core/models/painn/painn.py | 130 ++++++-- src/fairchem/core/models/schnet.py | 26 +- src/fairchem/core/models/scn/scn.py | 31 +- src/fairchem/core/trainers/base_trainer.py | 13 +- tests/core/e2e/test_s2ef.py | 56 +++- tests/core/models/test_configs/test_dpp.yml | 50 +++ .../models/test_configs/test_dpp_hydra.yml | 55 ++++ .../test_configs/test_equiformerv2_hydra.yml | 98 ++++++ .../models/test_configs/test_escn_hydra.yml | 67 ++++ .../models/test_configs/test_gemnet_dt.yml | 79 +++++ .../test_configs/test_gemnet_dt_hydra.yml | 86 +++++ .../test_gemnet_dt_hydra_grad.yml | 84 +++++ .../{test_gemnet.yml => test_gemnet_oc.yml} | 0 .../test_configs/test_gemnet_oc_hydra.yml | 112 +++++++ .../test_gemnet_oc_hydra_grad.yml | 109 +++++++ tests/core/models/test_configs/test_painn.yml | 50 +++ .../models/test_configs/test_painn_hydra.yml | 58 ++++ .../core/models/test_configs/test_schnet.yml | 45 +++ tests/core/models/test_configs/test_scn.yml | 59 ++++ tests/core/models/test_dimenetpp.py | 3 - tests/core/models/test_equiformer_v2.py | 3 - tests/core/models/test_gemnet.py | 3 - tests/core/models/test_gemnet_oc.py | 3 - .../models/test_gemnet_oc_scaling_mismatch.py | 12 - tests/core/models/test_schnet.py | 2 +- 34 files changed, 2187 insertions(+), 355 deletions(-) rename src/fairchem/core/models/equiformer_v2/{equiformer_v2_oc20.py => equiformer_v2.py} (72%) create mode 100755 tests/core/models/test_configs/test_dpp.yml create mode 100755 tests/core/models/test_configs/test_dpp_hydra.yml create mode 100644 
tests/core/models/test_configs/test_equiformerv2_hydra.yml create mode 100644 tests/core/models/test_configs/test_escn_hydra.yml create mode 100644 tests/core/models/test_configs/test_gemnet_dt.yml create mode 100644 tests/core/models/test_configs/test_gemnet_dt_hydra.yml create mode 100644 tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml rename tests/core/models/test_configs/{test_gemnet.yml => test_gemnet_oc.yml} (100%) create mode 100644 tests/core/models/test_configs/test_gemnet_oc_hydra.yml create mode 100644 tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml create mode 100644 tests/core/models/test_configs/test_painn.yml create mode 100644 tests/core/models/test_configs/test_painn_hydra.yml create mode 100755 tests/core/models/test_configs/test_schnet.yml create mode 100755 tests/core/models/test_configs/test_scn.yml diff --git a/docs/legacy_tutorials/OCP_Tutorial.md b/docs/legacy_tutorials/OCP_Tutorial.md index 8b5d4d522..19fd93f6b 100644 --- a/docs/legacy_tutorials/OCP_Tutorial.md +++ b/docs/legacy_tutorials/OCP_Tutorial.md @@ -1807,7 +1807,7 @@ Similarly, to predict forces, we pass edge features through a fully-connected la @registry.register_model("simple") class SimpleAtomEdgeModel(torch.nn.Module): - def __init__(self, num_atoms, bond_feat_dim, num_targets, emb_size=64, num_radial=64, cutoff=6.0, env_exponent=5): + def __init__(self, emb_size=64, num_radial=64, cutoff=6.0, env_exponent=5): super().__init__() self.radial_basis = RadialBasis( diff --git a/src/fairchem/core/models/base.py b/src/fairchem/core/models/base.py index 42790643a..eb8c9d543 100644 --- a/src/fairchem/core/models/base.py +++ b/src/fairchem/core/models/base.py @@ -8,27 +8,42 @@ from __future__ import annotations import logging +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING import torch -import torch.nn as nn +from torch import nn from torch_geometric.nn import radius_graph +from fairchem.core.common.registry import registry from fairchem.core.common.utils import ( compute_neighbors, get_pbc_distances, radius_graph_pbc, ) +if TYPE_CHECKING: + from torch_geometric.data import Batch -class BaseModel(nn.Module): - def __init__(self, num_atoms=None, bond_feat_dim=None, num_targets=None) -> None: - super().__init__() - self.num_atoms = num_atoms - self.bond_feat_dim = bond_feat_dim - self.num_targets = num_targets - def forward(self, data): - raise NotImplementedError +@dataclass +class GraphData: + """Class to keep graph attributes nicely packaged.""" + + edge_index: torch.Tensor + edge_distance: torch.Tensor + edge_distance_vec: torch.Tensor + cell_offsets: torch.Tensor + offset_distances: torch.Tensor + neighbors: torch.Tensor + batch_full: torch.Tensor # used for GP functionality + atomic_numbers_full: torch.Tensor # used for GP functionality + node_offset: int = 0 # used for GP functionality + + +class GraphModelMixin: + """Mixin Model class implementing some general convenience properties and methods.""" def generate_graph( self, @@ -109,13 +124,16 @@ def generate_graph( ) neighbors = compute_neighbors(data, edge_index) - return ( - edge_index, - edge_dist, - distance_vec, - cell_offsets, - cell_offset_distances, - neighbors, + return GraphData( + edge_index=edge_index, + edge_distance=edge_dist, + edge_distance_vec=distance_vec, + cell_offsets=cell_offsets, + offset_distances=cell_offset_distances, + neighbors=neighbors, + node_offset=0, + batch_full=data.batch, + atomic_numbers_full=data.atomic_numbers.long(), ) @property 
@@ -130,3 +148,90 @@ def no_weight_decay(self) -> list: if "embedding" in name or "frequencies" in name or "bias" in name: no_wd_list.append(name) return no_wd_list + + +class HeadInterface(metaclass=ABCMeta): + @abstractmethod + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + """Head forward. + + Arguments + --------- + data: DataBatch + Atomic systems as input + emb: dict[str->torch.Tensor] + Embeddings of the input as generated by the backbone + + Returns + ------- + outputs: dict[str->torch.Tensor] + Return one or more targets generated by this head + """ + return + + +class BackboneInterface(metaclass=ABCMeta): + @abstractmethod + def forward(self, data: Batch) -> dict[str, torch.Tensor]: + """Backbone forward. + + Arguments + --------- + data: DataBatch + Atomic systems as input + + Returns + ------- + embedding: dict[str->torch.Tensor] + Return backbone embeddings for the given input + """ + return + + +@registry.register_model("hydra") +class HydraModel(nn.Module, GraphModelMixin): + def __init__( + self, + backbone: dict, + heads: dict, + otf_graph: bool = True, + ): + super().__init__() + self.otf_graph = otf_graph + + backbone_model_name = backbone.pop("model") + self.backbone: BackboneInterface = registry.get_model_class( + backbone_model_name + )( + **backbone, + ) + + # Iterate through outputs_cfg and create heads + self.output_heads: dict[str, HeadInterface] = {} + + head_names_sorted = sorted(heads.keys()) + for head_name in head_names_sorted: + head_config = heads[head_name] + if "module" not in head_config: + raise ValueError( + f"{head_name} head does not specify module to use for the head" + ) + + module_name = head_config.pop("module") + self.output_heads[head_name] = registry.get_model_class(module_name)( + self.backbone, + **head_config, + ) + + self.output_heads = torch.nn.ModuleDict(self.output_heads) + + def forward(self, data: Batch): + emb = self.backbone(data) + # Predict all output properties for all structures in the batch for now. + out = {} + for k in self.output_heads: + out.update(self.output_heads[k](data, emb)) + + return out diff --git a/src/fairchem/core/models/dimenet_plus_plus.py b/src/fairchem/core/models/dimenet_plus_plus.py index 296a77bbb..aa08ea067 100644 --- a/src/fairchem/core/models/dimenet_plus_plus.py +++ b/src/fairchem/core/models/dimenet_plus_plus.py @@ -34,6 +34,8 @@ from __future__ import annotations +import typing + import torch from torch import nn from torch_geometric.nn.inits import glorot_orthogonal @@ -49,7 +51,10 @@ from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface + +if typing.TYPE_CHECKING: + from torch_geometric.data.batch import Batch try: import sympy as sym @@ -57,7 +62,7 @@ sym = None -class InteractionPPBlock(torch.nn.Module): +class InteractionPPBlock(nn.Module): def __init__( self, hidden_channels: int, @@ -90,11 +95,11 @@ def __init__( self.lin_up = nn.Linear(int_emb_size, hidden_channels, bias=False) # Residual layers before and after skip connection. 
- self.layers_before_skip = torch.nn.ModuleList( + self.layers_before_skip = nn.ModuleList( [ResidualLayer(hidden_channels, act) for _ in range(num_before_skip)] ) self.lin = nn.Linear(hidden_channels, hidden_channels) - self.layers_after_skip = torch.nn.ModuleList( + self.layers_after_skip = nn.ModuleList( [ResidualLayer(hidden_channels, act) for _ in range(num_after_skip)] ) @@ -153,7 +158,7 @@ def forward(self, x, rbf, sbf, idx_kj, idx_ji): return h -class OutputPPBlock(torch.nn.Module): +class OutputPPBlock(nn.Module): def __init__( self, num_radial: int, @@ -169,7 +174,7 @@ def __init__( self.lin_rbf = nn.Linear(num_radial, hidden_channels, bias=False) self.lin_up = nn.Linear(hidden_channels, out_emb_channels, bias=True) - self.lins = torch.nn.ModuleList() + self.lins = nn.ModuleList() for _ in range(num_layers): self.lins.append(nn.Linear(out_emb_channels, out_emb_channels)) self.lin = nn.Linear(out_emb_channels, out_channels, bias=False) @@ -193,7 +198,7 @@ def forward(self, x, rbf, i, num_nodes: int | None = None): return self.lin(x) -class DimeNetPlusPlus(torch.nn.Module): +class DimeNetPlusPlus(nn.Module): r"""DimeNet++ implementation based on https://github.com/klicperajo/dimenet. Args: @@ -241,7 +246,6 @@ def __init__( act = activation_resolver(act) super().__init__() - self.cutoff = cutoff if sym is None: @@ -256,7 +260,7 @@ def __init__( self.emb = EmbeddingBlock(num_radial, hidden_channels, act) - self.output_blocks = torch.nn.ModuleList( + self.output_blocks = nn.ModuleList( [ OutputPPBlock( num_radial, @@ -270,7 +274,7 @@ def __init__( ] ) - self.interaction_blocks = torch.nn.ModuleList( + self.interaction_blocks = nn.ModuleList( [ InteractionPPBlock( hidden_channels, @@ -330,13 +334,42 @@ def forward(self, z, pos, batch=None): raise NotImplementedError +@registry.register_model("dimenetplusplus_energy_and_force_head") +class DimeNetPlusPlusWrapEnergyAndForceHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + self.regress_forces = backbone.regress_forces + + @conditional_grad(torch.enable_grad()) + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + outputs = { + "energy": ( + emb["P"].sum(dim=0) + if data.batch is None + else scatter(emb["P"], data.batch, dim=0) + ) + } + if self.regress_forces: + outputs["forces"] = ( + -1 + * ( + torch.autograd.grad( + outputs["energy"], + data.pos, + grad_outputs=torch.ones_like(outputs["energy"]), + create_graph=True, + )[0] + ) + ) + return outputs + + @registry.register_model("dimenetplusplus") -class DimeNetPlusPlusWrap(DimeNetPlusPlus, BaseModel): +class DimeNetPlusPlusWrap(DimeNetPlusPlus, GraphModelMixin): def __init__( self, - num_atoms: int, - bond_feat_dim: int, # not used - num_targets: int, use_pbc: bool = True, regress_forces: bool = True, hidden_channels: int = 128, @@ -353,7 +386,6 @@ def __init__( num_after_skip: int = 2, num_output_layers: int = 3, ) -> None: - self.num_targets = num_targets self.regress_forces = regress_forces self.use_pbc = use_pbc self.cutoff = cutoff @@ -362,7 +394,7 @@ def __init__( super().__init__( hidden_channels=hidden_channels, - out_channels=num_targets, + out_channels=1, num_blocks=num_blocks, int_emb_size=int_emb_size, basis_emb_size=basis_emb_size, @@ -380,22 +412,15 @@ def __init__( def _forward(self, data): pos = data.pos batch = data.batch - ( - edge_index, - dist, - _, - cell_offsets, - offsets, - neighbors, - ) = self.generate_graph(data) - - data.edge_index = edge_index - data.cell_offsets = cell_offsets 
- data.neighbors = neighbors - j, i = edge_index + graph = self.generate_graph(data) + + data.edge_index = graph.edge_index + data.cell_offsets = graph.cell_offsets + data.neighbors = graph.neighbors + j, i = graph.edge_index _, _, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets( - edge_index, + graph.edge_index, data.cell_offsets, num_nodes=data.atomic_numbers.size(0), ) @@ -405,8 +430,8 @@ def _forward(self, data): pos_j = pos[idx_j].detach() if self.use_pbc: pos_ji, pos_kj = ( - pos[idx_j].detach() - pos_i + offsets[idx_ji], - pos[idx_k].detach() - pos_j + offsets[idx_kj], + pos[idx_j].detach() - pos_i + graph.offset_distances[idx_ji], + pos[idx_k].detach() - pos_j + graph.offset_distances[idx_kj], ) else: pos_ji, pos_kj = ( @@ -418,8 +443,8 @@ def _forward(self, data): b = torch.cross(pos_ji, pos_kj).norm(dim=-1) angle = torch.atan2(b, a) - rbf = self.rbf(dist) - sbf = self.sbf(dist, angle, idx_kj) + rbf = self.rbf(graph.edge_distance) + sbf = self.sbf(graph.edge_distance, angle, idx_kj) # Embedding block. x = self.emb(data.atomic_numbers.long(), rbf, i, j) @@ -459,3 +484,57 @@ def forward(self, data): @property def num_params(self) -> int: return sum(p.numel() for p in self.parameters()) + + +@registry.register_model("dimenetplusplus_backbone") +class DimeNetPlusPlusWrapBackbone(DimeNetPlusPlusWrap, BackboneInterface): + @conditional_grad(torch.enable_grad()) + def forward(self, data: Batch) -> dict[str, torch.Tensor]: + if self.regress_forces: + data.pos.requires_grad_(True) + pos = data.pos + graph = self.generate_graph(data) + data.edge_index = graph.edge_index + data.cell_offsets = graph.cell_offsets + data.neighbors = graph.neighbors + j, i = graph.edge_index + + _, _, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets( + graph.edge_index, + data.cell_offsets, + num_nodes=data.atomic_numbers.size(0), + ) + + # Calculate angles. + pos_i = pos[idx_i].detach() + pos_j = pos[idx_j].detach() + if self.use_pbc: + pos_ji, pos_kj = ( + pos[idx_j].detach() - pos_i + graph.offset_distances[idx_ji], + pos[idx_k].detach() - pos_j + graph.offset_distances[idx_kj], + ) + else: + pos_ji, pos_kj = ( + pos[idx_j].detach() - pos_i, + pos[idx_k].detach() - pos_j, + ) + + a = (pos_ji * pos_kj).sum(dim=-1) + b = torch.cross(pos_ji, pos_kj).norm(dim=-1) + angle = torch.atan2(b, a) + + rbf = self.rbf(graph.edge_distance) + sbf = self.sbf(graph.edge_distance, angle, idx_kj) + + # Embedding block. + x = self.emb(data.atomic_numbers.long(), rbf, i, j) + P = self.output_blocks[0](x, rbf, i, num_nodes=pos.size(0)) + + # Interaction blocks. 
+ for interaction_block, output_block in zip( + self.interaction_blocks, self.output_blocks[1:] + ): + x = interaction_block(x, rbf, sbf, idx_kj, idx_ji) + P += output_block(x, rbf, i, num_nodes=pos.size(0)) + + return {"P": P, "edge_embedding": x, "edge_idx": i} diff --git a/src/fairchem/core/models/equiformer_v2/__init__.py b/src/fairchem/core/models/equiformer_v2/__init__.py index 424b64f9e..720f890f6 100644 --- a/src/fairchem/core/models/equiformer_v2/__init__.py +++ b/src/fairchem/core/models/equiformer_v2/__init__.py @@ -1,5 +1,5 @@ from __future__ import annotations -from .equiformer_v2_oc20 import EquiformerV2_OC20 as EquiformerV2 +from .equiformer_v2 import EquiformerV2 __all__ = ["EquiformerV2"] diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2_oc20.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py similarity index 72% rename from src/fairchem/core/models/equiformer_v2/equiformer_v2_oc20.py rename to src/fairchem/core/models/equiformer_v2/equiformer_v2.py index 8edf81319..e2625eada 100644 --- a/src/fairchem/core/models/equiformer_v2/equiformer_v2_oc20.py +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py @@ -10,13 +10,15 @@ from fairchem.core.common import gp_utils from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface from fairchem.core.models.scn.smearing import GaussianSmearing with contextlib.suppress(ImportError): pass +import typing + from .edge_rot_mat import init_edge_rot_mat from .gaussian_rbf import GaussianRadialBasisLayer from .input_block import EdgeDegreeEmbedding @@ -42,13 +44,18 @@ TransBlockV2, ) +if typing.TYPE_CHECKING: + from torch_geometric.data.batch import Batch + + from fairchem.core.models.base import GraphData + # Statistics of IS2RE 100K _AVG_NUM_NODES = 77.81317 _AVG_DEGREE = 23.395238876342773 # IS2RE: 100k, max_radius = 5, max_neighbors = 100 @registry.register_model("equiformer_v2") -class EquiformerV2_OC20(BaseModel): +class EquiformerV2(nn.Module, GraphModelMixin): """ Equiformer with graph attention built upon SO(2) convolution and feedforward network built upon S2 activation @@ -108,9 +115,6 @@ class EquiformerV2_OC20(BaseModel): def __init__( self, - num_atoms: int, # not used - bond_feat_dim: int, # not used - num_targets: int, # not used use_pbc: bool = True, regress_forces: bool = True, otf_graph: bool = True, @@ -436,23 +440,12 @@ def forward(self, data): self.dtype = data.pos.dtype self.device = data.pos.device atomic_numbers = data.atomic_numbers.long() - - ( - edge_index, - edge_distance, - edge_distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph( + graph = self.generate_graph( data, enforce_max_neighbors_strictly=self.enforce_max_neighbors_strictly, ) - data_batch_full = data.batch data_batch = data.batch - atomic_numbers_full = atomic_numbers - node_offset = 0 if gp_utils.initialized(): ( atomic_numbers, @@ -462,12 +455,17 @@ def forward(self, data): edge_distance, edge_distance_vec, ) = self._init_gp_partitions( - atomic_numbers_full, - data_batch_full, - edge_index, - edge_distance, - edge_distance_vec, + graph.atomic_numbers_full, + graph.batch_full, + graph.edge_index, + graph.edge_distance, + graph.edge_distance_vec, ) + graph.node_offset = node_offset + graph.edge_index = edge_index + graph.edge_distance = edge_distance + graph.edge_distance_vec = 
edge_distance_vec + ############################################################### # Entering Graph Parallel Region # after this point, if using gp, then node, edge tensors are split @@ -485,7 +483,9 @@ def forward(self, data): ############################################################### # Compute 3x3 rotation matrix per edge - edge_rot_mat = self._init_edge_rot_mat(data, edge_index, edge_distance_vec) + edge_rot_mat = self._init_edge_rot_mat( + data, graph.edge_index, graph.edge_distance_vec + ) # Initialize the WignerD matrices and other values for spherical harmonic calculations for i in range(self.num_resolutions): @@ -496,7 +496,6 @@ def forward(self, data): ############################################################### # Init per node representations using an atomic number based embedding - offset = 0 x = SO3_Embedding( len(atomic_numbers), self.lmax_list, @@ -519,27 +518,27 @@ def forward(self, data): offset_res = offset_res + int((self.lmax_list[i] + 1) ** 2) # Edge encoding (distance and atom edge) - edge_distance = self.distance_expansion(edge_distance) + graph.edge_distance = self.distance_expansion(graph.edge_distance) if self.share_atom_edge_embedding and self.use_atom_edge_embedding: - source_element = atomic_numbers_full[ - edge_index[0] + source_element = graph.atomic_numbers_full[ + graph.edge_index[0] ] # Source atom atomic number - target_element = atomic_numbers_full[ - edge_index[1] + target_element = graph.atomic_numbers_full[ + graph.edge_index[1] ] # Target atom atomic number source_embedding = self.source_embedding(source_element) target_embedding = self.target_embedding(target_element) - edge_distance = torch.cat( - (edge_distance, source_embedding, target_embedding), dim=1 + graph.edge_distance = torch.cat( + (graph.edge_distance, source_embedding, target_embedding), dim=1 ) # Edge-degree embedding edge_degree = self.edge_degree_embedding( - atomic_numbers_full, - edge_distance, - edge_index, + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, len(atomic_numbers), - node_offset, + graph.node_offset, ) x.embedding = x.embedding + edge_degree.embedding @@ -550,11 +549,11 @@ def forward(self, data): for i in range(self.num_layers): x = self.blocks[i]( x, # SO3_Embedding - atomic_numbers_full, - edge_distance, - edge_index, + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, batch=data_batch, # for GraphDropPath - node_offset=node_offset, + node_offset=graph.node_offset, ) # Final layer norm @@ -572,7 +571,7 @@ def forward(self, data): device=node_energy.device, dtype=node_energy.dtype, ) - energy.index_add_(0, data_batch_full, node_energy.view(-1)) + energy.index_add_(0, graph.batch_full, node_energy.view(-1)) energy = energy / self.avg_num_nodes # Add the per-atom linear references to the energy. 
@@ -594,8 +593,8 @@ def forward(self, data): with torch.cuda.amp.autocast(False): energy = energy.to(self.energy_lin_ref.dtype).index_add( 0, - data_batch_full, - self.energy_lin_ref[atomic_numbers_full], + graph.batch_full, + self.energy_lin_ref[graph.atomic_numbers_full], ) outputs = {"energy": energy} @@ -605,10 +604,10 @@ def forward(self, data): if self.regress_forces: forces = self.force_block( x, - atomic_numbers_full, - edge_distance, - edge_index, - node_offset=node_offset, + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, + node_offset=graph.node_offset, ) forces = forces.embedding.narrow(1, 1, 3) forces = forces.view(-1, 3).contiguous() @@ -678,3 +677,209 @@ def no_weight_decay(self) -> set: no_wd_list.append(global_parameter_name) return set(no_wd_list) + + +@registry.register_model("equiformer_v2_backbone") +class EquiformerV2Backbone(EquiformerV2, BackboneInterface): + @conditional_grad(torch.enable_grad()) + def forward(self, data: Batch) -> dict[str, torch.Tensor]: + self.batch_size = len(data.natoms) + self.dtype = data.pos.dtype + self.device = data.pos.device + atomic_numbers = data.atomic_numbers.long() + graph = self.generate_graph( + data, + enforce_max_neighbors_strictly=self.enforce_max_neighbors_strictly, + ) + + data_batch = data.batch + if gp_utils.initialized(): + ( + atomic_numbers, + data_batch, + node_offset, + edge_index, + edge_distance, + edge_distance_vec, + ) = self._init_gp_partitions( + graph.atomic_numbers_full, + graph.batch_full, + graph.edge_index, + graph.edge_distance, + graph.edge_distance_vec, + ) + graph.node_offset = node_offset + graph.edge_index = edge_index + graph.edge_distance = edge_distance + graph.edge_distance_vec = edge_distance_vec + + ############################################################### + # Entering Graph Parallel Region + # after this point, if using gp, then node, edge tensors are split + # across the graph parallel ranks, some full tensors such as + # atomic_numbers_full are required because we need to index into the + # full graph when computing edge embeddings or reducing nodes from neighbors + # + # all tensors that do not have the suffix "_full" refer to the partial tensors. 
+ # if not using gp, the full values are equal to the partial values + # ie: atomic_numbers_full == atomic_numbers + ############################################################### + + ############################################################### + # Initialize data structures + ############################################################### + + # Compute 3x3 rotation matrix per edge + edge_rot_mat = self._init_edge_rot_mat( + data, graph.edge_index, graph.edge_distance_vec + ) + + # Initialize the WignerD matrices and other values for spherical harmonic calculations + for i in range(self.num_resolutions): + self.SO3_rotation[i].set_wigner(edge_rot_mat) + + ############################################################### + # Initialize node embeddings + ############################################################### + + # Init per node representations using an atomic number based embedding + x = SO3_Embedding( + len(atomic_numbers), + self.lmax_list, + self.sphere_channels, + self.device, + self.dtype, + ) + + offset_res = 0 + offset = 0 + # Initialize the l = 0, m = 0 coefficients for each resolution + for i in range(self.num_resolutions): + if self.num_resolutions == 1: + x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers) + else: + x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers)[ + :, offset : offset + self.sphere_channels + ] + offset = offset + self.sphere_channels + offset_res = offset_res + int((self.lmax_list[i] + 1) ** 2) + + # Edge encoding (distance and atom edge) + graph.edge_distance = self.distance_expansion(graph.edge_distance) + if self.share_atom_edge_embedding and self.use_atom_edge_embedding: + source_element = graph.atomic_numbers_full[ + graph.edge_index[0] + ] # Source atom atomic number + target_element = graph.atomic_numbers_full[ + graph.edge_index[1] + ] # Target atom atomic number + source_embedding = self.source_embedding(source_element) + target_embedding = self.target_embedding(target_element) + graph.edge_distance = torch.cat( + (graph.edge_distance, source_embedding, target_embedding), dim=1 + ) + + # Edge-degree embedding + edge_degree = self.edge_degree_embedding( + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, + len(atomic_numbers), + graph.node_offset, + ) + x.embedding = x.embedding + edge_degree.embedding + + ############################################################### + # Update spherical node embeddings + ############################################################### + + for i in range(self.num_layers): + x = self.blocks[i]( + x, # SO3_Embedding + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, + batch=data_batch, # for GraphDropPath + node_offset=graph.node_offset, + ) + + # Final layer norm + x.embedding = self.norm(x.embedding) + + return {"node_embedding": x, "graph": graph} + + +@registry.register_model("equiformer_v2_energy_head") +class EquiformerV2EnergyHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + self.avg_num_nodes = backbone.avg_num_nodes + self.energy_block = FeedForwardNetwork( + backbone.sphere_channels, + backbone.ffn_hidden_channels, + 1, + backbone.lmax_list, + backbone.mmax_list, + backbone.SO3_grid, + backbone.ffn_activation, + backbone.use_gate_act, + backbone.use_grid_mlp, + backbone.use_sep_s2_act, + ) + + def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]): + node_energy = self.energy_block(emb["node_embedding"]) + node_energy = node_energy.embedding.narrow(1, 0, 1) + if 
gp_utils.initialized(): + node_energy = gp_utils.gather_from_model_parallel_region(node_energy, dim=0) + energy = torch.zeros( + len(data.natoms), + device=node_energy.device, + dtype=node_energy.dtype, + ) + energy.index_add_(0, data.batch, node_energy.view(-1)) + return {"energy": energy / self.avg_num_nodes} + + +@registry.register_model("equiformer_v2_force_head") +class EquiformerV2ForceHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + + self.force_block = SO2EquivariantGraphAttention( + backbone.sphere_channels, + backbone.attn_hidden_channels, + backbone.num_heads, + backbone.attn_alpha_channels, + backbone.attn_value_channels, + 1, + backbone.lmax_list, + backbone.mmax_list, + backbone.SO3_rotation, + backbone.mappingReduced, + backbone.SO3_grid, + backbone.max_num_elements, + backbone.edge_channels_list, + backbone.block_use_atom_edge_embedding, + backbone.use_m_share_rad, + backbone.attn_activation, + backbone.use_s2_act_attn, + backbone.use_attn_renorm, + backbone.use_gate_act, + backbone.use_sep_s2_act, + alpha_drop=0.0, + ) + + def forward(self, data: Batch, emb: dict[str, torch.Tensor]): + forces = self.force_block( + emb["node_embedding"], + emb["graph"].atomic_numbers_full, + emb["graph"].edge_distance, + emb["graph"].edge_index, + node_offset=emb["graph"].node_offset, + ) + forces = forces.embedding.narrow(1, 1, 3) + forces = forces.view(-1, 3).contiguous() + if gp_utils.initialized(): + forces = gp_utils.gather_from_model_parallel_region(forces, dim=0) + return {"forces": forces} diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py index 0ec66b9db..dfa872c39 100644 --- a/src/fairchem/core/models/escn/escn.py +++ b/src/fairchem/core/models/escn/escn.py @@ -10,13 +10,17 @@ import contextlib import logging import time +import typing import torch import torch.nn as nn +if typing.TYPE_CHECKING: + from torch_geometric.data.batch import Batch + from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface from fairchem.core.models.escn.so3 import ( CoefficientMapping, SO3_Embedding, @@ -36,7 +40,7 @@ @registry.register_model("escn") -class eSCN(BaseModel): +class eSCN(nn.Module, GraphModelMixin): """Equivariant Spherical Channel Network Paper: Reducing SO(3) Convolutions to SO(2) for Efficient Equivariant GNNs @@ -64,9 +68,6 @@ class eSCN(BaseModel): def __init__( self, - num_atoms: int, # not used - bond_feat_dim: int, # not used - num_targets: int, # not used use_pbc: bool = True, regress_forces: bool = True, otf_graph: bool = False, @@ -79,7 +80,6 @@ def __init__( sphere_channels: int = 128, hidden_channels: int = 256, edge_channels: int = 128, - use_grid: bool = True, num_sphere_samples: int = 128, distance_function: str = "gaussian", basis_width_scalar: float = 1.0, @@ -232,22 +232,16 @@ def forward(self, data): start_time = time.time() atomic_numbers = data.atomic_numbers.long() num_atoms = len(atomic_numbers) - - ( - edge_index, - edge_distance, - edge_distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph(data) + graph = self.generate_graph(data) ############################################################### # Initialize data structures ############################################################### # Compute 3x3 rotation matrix per edge - edge_rot_mat = 
self._init_edge_rot_mat(data, edge_index, edge_distance_vec) + edge_rot_mat = self._init_edge_rot_mat( + data, graph.edge_index, graph.edge_distance_vec + ) # Initialize the WignerD matrices and other values for spherical harmonic calculations self.SO3_edge_rot = nn.ModuleList() @@ -290,8 +284,8 @@ def forward(self, data): x_message = self.layer_blocks[i]( x, atomic_numbers, - edge_distance, - edge_index, + graph.edge_distance, + graph.edge_index, self.SO3_edge_rot, mappingReduced, ) @@ -304,8 +298,8 @@ def forward(self, data): x = self.layer_blocks[i]( x, atomic_numbers, - edge_distance, - edge_index, + graph.edge_distance, + graph.edge_index, self.SO3_edge_rot, mappingReduced, ) @@ -421,6 +415,149 @@ def num_params(self) -> int: return sum(p.numel() for p in self.parameters()) +@registry.register_model("escn_backbone") +class eSCNBackbone(eSCN, BackboneInterface): + @conditional_grad(torch.enable_grad()) + def forward(self, data: Batch) -> dict[str, torch.Tensor]: + device = data.pos.device + self.batch_size = len(data.natoms) + self.dtype = data.pos.dtype + + atomic_numbers = data.atomic_numbers.long() + num_atoms = len(atomic_numbers) + + graph = self.generate_graph(data) + + ############################################################### + # Initialize data structures + ############################################################### + + # Compute 3x3 rotation matrix per edge + edge_rot_mat = self._init_edge_rot_mat( + data, graph.edge_index, graph.edge_distance_vec + ) + + # Initialize the WignerD matrices and other values for spherical harmonic calculations + self.SO3_edge_rot = nn.ModuleList() + for i in range(self.num_resolutions): + self.SO3_edge_rot.append(SO3_Rotation(edge_rot_mat, self.lmax_list[i])) + + ############################################################### + # Initialize node embeddings + ############################################################### + + # Init per node representations using an atomic number based embedding + offset = 0 + x = SO3_Embedding( + num_atoms, + self.lmax_list, + self.sphere_channels, + device, + self.dtype, + ) + + offset_res = 0 + offset = 0 + # Initialize the l=0,m=0 coefficients for each resolution + for i in range(self.num_resolutions): + x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers)[ + :, offset : offset + self.sphere_channels + ] + offset = offset + self.sphere_channels + offset_res = offset_res + int((self.lmax_list[i] + 1) ** 2) + + # This can be expensive to compute (not implemented efficiently), so only do it once and pass it along to each layer + mappingReduced = CoefficientMapping(self.lmax_list, self.mmax_list, device) + + ############################################################### + # Update spherical node embeddings + ############################################################### + + for i in range(self.num_layers): + if i > 0: + x_message = self.layer_blocks[i]( + x, + atomic_numbers, + graph.edge_distance, + graph.edge_index, + self.SO3_edge_rot, + mappingReduced, + ) + + # Residual layer for all layers past the first + x.embedding = x.embedding + x_message.embedding + + else: + # No residual for the first layer + x = self.layer_blocks[i]( + x, + atomic_numbers, + graph.edge_distance, + graph.edge_index, + self.SO3_edge_rot, + mappingReduced, + ) + + # Sample the spherical channels (node embeddings) at evenly distributed points on the sphere. + # These values are fed into the output blocks. 
+ x_pt = torch.tensor([], device=device) + offset = 0 + # Compute the embedding values at every sampled point on the sphere + for i in range(self.num_resolutions): + num_coefficients = int((x.lmax_list[i] + 1) ** 2) + x_pt = torch.cat( + [ + x_pt, + torch.einsum( + "abc, pb->apc", + x.embedding[:, offset : offset + num_coefficients], + self.sphharm_weights[i], + ).contiguous(), + ], + dim=2, + ) + offset = offset + num_coefficients + + x_pt = x_pt.view(-1, self.sphere_channels_all) + + return {"sphere_values": x_pt, "sphere_points": self.sphere_points} + + +@registry.register_model("escn_energy_head") +class eSCNEnergyHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + + # Output blocks for energy and forces + self.energy_block = EnergyBlock( + backbone.sphere_channels_all, backbone.num_sphere_samples, backbone.act + ) + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + node_energy = self.energy_block(emb["sphere_values"]) + energy = torch.zeros(len(data.natoms), device=data.pos.device) + energy.index_add_(0, data.batch, node_energy.view(-1)) + # Scale energy to help balance numerical precision w.r.t. forces + return {"energy": energy * 0.001} + + +@registry.register_model("escn_force_head") +class eSCNForceHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + + self.force_block = ForceBlock( + backbone.sphere_channels_all, backbone.num_sphere_samples, backbone.act + ) + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + return {"forces": self.force_block(emb["sphere_values"], emb["sphere_points"])} + + class LayerBlock(torch.nn.Module): """ Layer block: Perform one layer (message passing and aggregation) of the GNN diff --git a/src/fairchem/core/models/gemnet/gemnet.py b/src/fairchem/core/models/gemnet/gemnet.py index e719c219b..59b3eda08 100644 --- a/src/fairchem/core/models/gemnet/gemnet.py +++ b/src/fairchem/core/models/gemnet/gemnet.py @@ -7,14 +7,20 @@ from __future__ import annotations +import typing + import numpy as np import torch +import torch.nn as nn + +if typing.TYPE_CHECKING: + from torch_geometric.data.batch import Batch from torch_scatter import scatter from torch_sparse import SparseTensor from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface from fairchem.core.modules.scaling.compat import load_scales_compat from .layers.atom_update_block import OutputBlock @@ -28,17 +34,12 @@ @registry.register_model("gemnet_t") -class GemNetT(BaseModel): +class GemNetT(nn.Module, GraphModelMixin): """ GemNet-T, triplets-only variant of GemNet Parameters ---------- - num_atoms (int): Unused argument - bond_feat_dim (int): Unused argument - num_targets: int - Number of prediction targets. - num_spherical: int Controls maximum frequency. 
num_radial: int @@ -94,9 +95,6 @@ class GemNetT(BaseModel): def __init__( self, - num_atoms: int | None, - bond_feat_dim: int, - num_targets: int, num_spherical: int, num_radial: int, num_blocks: int, @@ -132,7 +130,6 @@ def __init__( if rbf is None: rbf = {"name": "gaussian"} super().__init__() - self.num_targets = num_targets assert num_blocks > 0 self.num_blocks = num_blocks self.extensive = extensive @@ -235,7 +232,7 @@ def __init__( emb_size_edge=emb_size_edge, emb_size_rbf=emb_size_rbf, nHidden=num_atom, - num_targets=num_targets, + num_targets=1, activation=activation, output_init=output_init, direct_forces=direct_forces, @@ -421,18 +418,10 @@ def select_edges( def generate_interaction_graph(self, data): num_atoms = data.atomic_numbers.size(0) - - ( - edge_index, - D_st, - distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph(data) + graph = self.generate_graph(data) # These vectors actually point in the opposite direction. # But we want to use col as idx_t for efficient aggregation. - V_st = -distance_vec / D_st[:, None] + V_st = -graph.edge_distance_vec / graph.edge_distance[:, None] # Mask interaction edges if required if self.otf_graph or np.isclose(self.cutoff, 6): @@ -447,10 +436,10 @@ def generate_interaction_graph(self, data): V_st, ) = self.select_edges( data=data, - edge_index=edge_index, - cell_offsets=cell_offsets, - neighbors=neighbors, - edge_dist=D_st, + edge_index=graph.edge_index, + cell_offsets=graph.cell_offsets, + neighbors=graph.neighbors, + edge_dist=graph.edge_distance, edge_vector=V_st, cutoff=select_cutoff, ) @@ -530,7 +519,7 @@ def forward(self, data): rbf_out = self.mlp_rbf_out(rbf) E_t, F_st = self.out_blocks[0](h, m, rbf_out, idx_t) - # (nAtoms, num_targets), (nEdges, num_targets) + # (nAtoms, 1), (nEdges, 1) for i in range(self.num_blocks): # Interaction block @@ -549,7 +538,7 @@ def forward(self, data): ) # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) E, F = self.out_blocks[i + 1](h, m, rbf_out, idx_t) - # (nAtoms, num_targets), (nEdges, num_targets) + # (nAtoms, 1), (nEdges, 1) F_st += F E_t += E @@ -557,11 +546,11 @@ def forward(self, data): if self.extensive: E_t = scatter( E_t, batch, dim=0, dim_size=nMolecules, reduce="add" - ) # (nMolecules, num_targets) + ) # (nMolecules, 1) else: E_t = scatter( E_t, batch, dim=0, dim_size=nMolecules, reduce="mean" - ) # (nMolecules, num_targets) + ) # (nMolecules, 1) outputs = {"energy": E_t} @@ -569,30 +558,18 @@ def forward(self, data): if self.direct_forces: # map forces in edge directions F_st_vec = F_st[:, :, None] * V_st[:, None, :] - # (nEdges, num_targets, 3) + # (nEdges, 1, 3) F_t = scatter( F_st_vec, idx_t, dim=0, dim_size=data.atomic_numbers.size(0), reduce="add", - ) # (nAtoms, num_targets, 3) + ) # (nAtoms, 1, 3) F_t = F_t.squeeze(1) # (nAtoms, 3) else: - if self.num_targets > 1: - forces = [] - for i in range(self.num_targets): - # maybe this can be solved differently - forces += [ - -torch.autograd.grad( - E_t[:, i].sum(), pos, create_graph=True - )[0] - ] - F_t = torch.stack(forces, dim=1) - # (nAtoms, num_targets, 3) - else: - F_t = -torch.autograd.grad(E_t.sum(), pos, create_graph=True)[0] - # (nAtoms, 3) + F_t = -torch.autograd.grad(E_t.sum(), pos, create_graph=True)[0] + # (nAtoms, 3) outputs["forces"] = F_t @@ -601,3 +578,129 @@ def forward(self, data): @property def num_params(self): return sum(p.numel() for p in self.parameters()) + + +@registry.register_model("gemnet_t_backbone") +class GemNetTBackbone(GemNetT, BackboneInterface): + 
@conditional_grad(torch.enable_grad()) + def forward(self, data: Batch) -> dict[str, torch.Tensor]: + pos = data.pos + atomic_numbers = data.atomic_numbers.long() + + if self.regress_forces and not self.direct_forces: + pos.requires_grad_(True) + + ( + edge_index, + neighbors, + D_st, + V_st, + id_swap, + id3_ba, + id3_ca, + id3_ragged_idx, + ) = self.generate_interaction_graph(data) + idx_s, idx_t = edge_index + + # Calculate triplet angles + cosφ_cab = inner_product_normalized(V_st[id3_ca], V_st[id3_ba]) + rad_cbf3, cbf3 = self.cbf_basis3(D_st, cosφ_cab, id3_ca) + + rbf = self.radial_basis(D_st) + + # Embedding block + h = self.atom_emb(atomic_numbers) + # (nAtoms, emb_size_atom) + m = self.edge_emb(h, rbf, idx_s, idx_t) # (nEdges, emb_size_edge) + + rbf3 = self.mlp_rbf3(rbf) + cbf3 = self.mlp_cbf3(rad_cbf3, cbf3, id3_ca, id3_ragged_idx) + + rbf_h = self.mlp_rbf_h(rbf) + rbf_out = self.mlp_rbf_out(rbf) + + E_t, F_st = self.out_blocks[0](h, m, rbf_out, idx_t) + # (nAtoms, 1), (nEdges, 1) + + for i in range(self.num_blocks): + # Interaction block + h, m = self.int_blocks[i]( + h=h, + m=m, + rbf3=rbf3, + cbf3=cbf3, + id3_ragged_idx=id3_ragged_idx, + id_swap=id_swap, + id3_ba=id3_ba, + id3_ca=id3_ca, + rbf_h=rbf_h, + idx_s=idx_s, + idx_t=idx_t, + ) # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) + + E, F = self.out_blocks[i + 1](h, m, rbf_out, idx_t) + # (nAtoms, 1), (nEdges, 1) + F_st += F + E_t += E + return { + "F_st": F_st, + "E_t": E_t, + "edge_vec": V_st, + "edge_idx": idx_t, + "node_embedding": h, + "edge_embedding": m, + } + + +@registry.register_model("gemnet_t_energy_and_grad_force_head") +class GemNetTEnergyAndGradForceHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + self.extensive = backbone.extensive + self.regress_forces = backbone.regress_forces + self.direct_forces = backbone.direct_forces + + @conditional_grad(torch.enable_grad()) + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + nMolecules = torch.max(data.batch) + 1 + if self.extensive: + E_t = scatter( + emb["E_t"], data.batch, dim=0, dim_size=nMolecules, reduce="add" + ) # (nMolecules, 1) + else: + E_t = scatter( + emb["E_t"], data.batch, dim=0, dim_size=nMolecules, reduce="mean" + ) # (nMolecules, 1) + + outputs = {"energy": E_t} + + if self.regress_forces and not self.direct_forces: + outputs["forces"] = -torch.autograd.grad( + E_t.sum(), data.pos, create_graph=True + )[0] + # (nAtoms, 3) + return outputs + + +@registry.register_model("gemnet_t_force_head") +class GemNetTForceHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + self.direct_forces = backbone.direct_forces + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + # map forces in edge directions + F_st_vec = emb["F_st"][:, :, None] * emb["edge_vec"][:, None, :] + # (nEdges, 1, 3) + F_t = scatter( + F_st_vec, + emb["edge_idx"], + dim=0, + dim_size=data.atomic_numbers.size(0), + reduce="add", + ) # (nAtoms, 1, 3) + return {"forces": F_t.squeeze(1)} # (nAtoms, 3) diff --git a/src/fairchem/core/models/gemnet_gp/gemnet.py b/src/fairchem/core/models/gemnet_gp/gemnet.py index 81fbd4069..a75756dcc 100644 --- a/src/fairchem/core/models/gemnet_gp/gemnet.py +++ b/src/fairchem/core/models/gemnet_gp/gemnet.py @@ -9,13 +9,14 @@ import numpy as np import torch +from torch import nn from torch_scatter import scatter from torch_sparse import SparseTensor from fairchem.core.common import gp_utils from 
fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import GraphModelMixin from fairchem.core.modules.scaling.compat import load_scales_compat from .layers.atom_update_block import OutputBlock @@ -29,17 +30,12 @@ @registry.register_model("gp_gemnet_t") -class GraphParallelGemNetT(BaseModel): +class GraphParallelGemNetT(nn.Module, GraphModelMixin): """ GemNet-T, triplets-only variant of GemNet Parameters ---------- - num_atoms (int): Unused argument - bond_feat_dim (int): Unused argument - num_targets: int - Number of prediction targets. - num_spherical: int Controls maximum frequency. num_radial: int @@ -95,9 +91,6 @@ class GraphParallelGemNetT(BaseModel): def __init__( self, - num_atoms: int | None, - bond_feat_dim: int, - num_targets: int, num_spherical: int, num_radial: int, num_blocks: int, @@ -134,7 +127,6 @@ def __init__( if rbf is None: rbf = {"name": "gaussian"} super().__init__() - self.num_targets = num_targets assert num_blocks > 0 self.num_blocks = num_blocks self.extensive = extensive @@ -239,7 +231,7 @@ def __init__( emb_size_edge=emb_size_edge, emb_size_rbf=emb_size_rbf, nHidden=num_atom, - num_targets=num_targets, + num_targets=1, activation=activation, output_init=output_init, direct_forces=direct_forces, @@ -415,18 +407,10 @@ def select_edges( def generate_interaction_graph(self, data): num_atoms = data.atomic_numbers.size(0) - - ( - edge_index, - D_st, - distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph(data) + graph = self.generate_graph(data) # These vectors actually point in the opposite direction. # But we want to use col as idx_t for efficient aggregation. 
- V_st = -distance_vec / D_st[:, None] + V_st = -graph.edge_distance_vec / graph.edge_distance[:, None] # Mask interaction edges if required if self.otf_graph or np.isclose(self.cutoff, 6): @@ -441,10 +425,10 @@ V_st, ) = self.select_edges( data=data, - edge_index=edge_index, - cell_offsets=cell_offsets, - neighbors=neighbors, - edge_dist=D_st, + edge_index=graph.edge_index, + cell_offsets=graph.cell_offsets, + neighbors=graph.neighbors, + edge_dist=graph.edge_distance, edge_vector=V_st, cutoff=select_cutoff, ) @@ -563,7 +547,7 @@ def forward(self, data): rbf_out = self.mlp_rbf_out(rbf) E_t, F_st = self.out_blocks[0](nAtoms, m, rbf_out, idx_t) - # (nAtoms, num_targets), (nEdges, num_targets) + # (nAtoms, 1), (nEdges, 1) for i in range(self.num_blocks): # Interaction block @@ -585,7 +569,7 @@ ) # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) E, F = self.out_blocks[i + 1](nAtoms, m, rbf_out, idx_t) - # (nAtoms, num_targets), (nEdges, num_targets) + # (nAtoms, 1), (nEdges, 1) F_st += F E_t += E @@ -601,41 +585,29 @@ E_t = gp_utils.gather_from_model_parallel_region(E_t, dim=0) E_t = scatter( E_t, batch, dim=0, dim_size=nMolecules, reduce="add" - ) # (nMolecules, num_targets) + ) # (nMolecules, 1) else: E_t = scatter( E_t, batch, dim=0, dim_size=nMolecules, reduce="mean" - ) # (nMolecules, num_targets) + ) # (nMolecules, 1) outputs = {"energy": E_t} if self.regress_forces: if self.direct_forces: # map forces in edge directions F_st_vec = F_st[:, :, None] * V_st[:, None, :] - # (nEdges, num_targets, 3) + # (nEdges, 1, 3) F_t = scatter( F_st_vec, idx_t_full, dim=0, dim_size=data.atomic_numbers.size(0), reduce="add", - ) # (nAtoms, num_targets, 3) + ) # (nAtoms, 1, 3) F_t = F_t.squeeze(1) # (nAtoms, 3) else: - if self.num_targets > 1: - forces = [] - for i in range(self.num_targets): - # maybe this can be solved differently - forces += [ - -torch.autograd.grad( - E_t[:, i].sum(), pos, create_graph=True - )[0] - ] - F_t = torch.stack(forces, dim=1) - # (nAtoms, num_targets, 3) - else: - F_t = -torch.autograd.grad(E_t.sum(), pos, create_graph=True)[0] - # (nAtoms, 3) + F_t = -torch.autograd.grad(E_t.sum(), pos, create_graph=True)[0] + # (nAtoms, 3) outputs["forces"] = F_t diff --git a/src/fairchem/core/models/gemnet_oc/gemnet_oc.py b/src/fairchem/core/models/gemnet_oc/gemnet_oc.py index e1176d00c..0aea3d81b 100644 --- a/src/fairchem/core/models/gemnet_oc/gemnet_oc.py +++ b/src/fairchem/core/models/gemnet_oc/gemnet_oc.py @@ -7,9 +7,11 @@ from __future__ import annotations import logging +import typing import numpy as np import torch +import torch.nn as nn from torch_scatter import segment_coo from fairchem.core.common.registry import registry @@ -18,7 +20,7 @@ get_max_neighbors_mask, scatter_det, ) -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface from fairchem.core.modules.scaling.compat import load_scales_compat from .initializers import get_initializer @@ -40,17 +42,15 @@ repeat_blocks, ) +if typing.TYPE_CHECKING: + from torch_geometric.data.batch import Batch + @registry.register_model("gemnet_oc") -class GemNetOC(BaseModel): +class GemNetOC(nn.Module, GraphModelMixin): """ Arguments --------- - num_atoms (int): Unused argument - bond_feat_dim (int): Unused argument - num_targets: int - Number of prediction targets. - num_spherical: int Controls maximum frequency. 
num_radial: int @@ -179,9 +179,6 @@ class GemNetOC(BaseModel): def __init__( self, - num_atoms: int | None, - bond_feat_dim: int, - num_targets: int, num_spherical: int, num_radial: int, num_blocks: int, @@ -249,11 +246,11 @@ def __init__( super().__init__() if len(kwargs) > 0: logging.warning(f"Unrecognized arguments: {list(kwargs.keys())}") - self.num_targets = num_targets assert num_blocks > 0 self.num_blocks = num_blocks self.extensive = extensive + self.activation = activation self.atom_edge_interaction = atom_edge_interaction self.edge_atom_interaction = edge_atom_interaction self.atom_interaction = atom_interaction @@ -357,7 +354,7 @@ def __init__( for _ in range(num_global_out_layers) ] self.out_mlp_E = torch.nn.Sequential(*out_mlp_E) - self.out_energy = Dense(emb_size_atom, num_targets, bias=False, activation=None) + self.out_energy = Dense(emb_size_atom, 1, bias=False, activation=None) if direct_forces: out_mlp_F = [ Dense( @@ -373,9 +370,7 @@ def __init__( for _ in range(num_global_out_layers) ] self.out_mlp_F = torch.nn.Sequential(*out_mlp_F) - self.out_forces = Dense( - emb_size_edge, num_targets, bias=False, activation=None - ) + self.out_forces = Dense(emb_size_edge, 1, bias=False, activation=None) out_initializer = get_initializer(output_init) self.out_energy.reset_parameters(out_initializer) @@ -870,15 +865,7 @@ def subselect_edges( def generate_graph_dict(self, data, cutoff, max_neighbors): """Generate a radius/nearest neighbor graph.""" otf_graph = cutoff > 6 or max_neighbors > 50 or self.otf_graph - - ( - edge_index, - edge_dist, - distance_vec, - cell_offsets, - _, # cell offset distances - num_neighbors, - ) = self.generate_graph( + graph = self.generate_graph( data, cutoff=cutoff, max_neighbors=max_neighbors, @@ -886,15 +873,15 @@ def generate_graph_dict(self, data, cutoff, max_neighbors): ) # These vectors actually point in the opposite direction. # But we want to use col as idx_t for efficient aggregation. 
- edge_vector = -distance_vec / edge_dist[:, None] - cell_offsets = -cell_offsets # a - c + offset + edge_vector = -graph.edge_distance_vec / graph.edge_distance[:, None] + cell_offsets = -graph.cell_offsets # a - c + offset graph = { - "edge_index": edge_index, - "distance": edge_dist, + "edge_index": graph.edge_index, + "distance": graph.edge_distance, "vector": edge_vector, "cell_offset": cell_offsets, - "num_neighbors": num_neighbors, + "num_neighbors": graph.neighbors, } # Mask interaction edges if required @@ -1285,11 +1272,11 @@ def forward(self, data): if self.extensive: E_t = scatter_det( E_t, batch, dim=0, dim_size=nMolecules, reduce="add" - ) # (nMolecules, num_targets) + ) # (nMolecules, 1) else: E_t = scatter_det( E_t, batch, dim=0, dim_size=nMolecules, reduce="mean" - ) # (nMolecules, num_targets) + ) # (nMolecules, 1) E_t = E_t.squeeze(1) # (num_molecules) outputs = {"energy": E_t} @@ -1308,19 +1295,19 @@ def forward(self, data): dim=0, dim_size=int(nEdges / 2), reduce="mean", - ) # (nEdges/2, num_targets) - F_st = F_st[id_undir] # (nEdges, num_targets) + ) # (nEdges/2, 1) + F_st = F_st[id_undir] # (nEdges, 1) # map forces in edge directions F_st_vec = F_st[:, :, None] * main_graph["vector"][:, None, :] - # (nEdges, num_targets, 3) + # (nEdges, 1, 3) F_t = scatter_det( F_st_vec, idx_t, dim=0, dim_size=num_atoms, reduce="add", - ) # (nAtoms, num_targets, 3) + ) # (nAtoms, 1, 3) else: F_t = self.force_scaler.calc_forces_and_update(E_t, pos) @@ -1333,3 +1320,233 @@ def forward(self, data): @property def num_params(self) -> int: return sum(p.numel() for p in self.parameters()) + + +@registry.register_model("gemnet_oc_backbone") +class GemNetOCBackbone(GemNetOC, BackboneInterface): + @conditional_grad(torch.enable_grad()) + def forward(self, data: Batch) -> dict[str, torch.Tensor]: + pos = data.pos + atomic_numbers = data.atomic_numbers.long() + num_atoms = atomic_numbers.shape[0] + + if self.regress_forces and not self.direct_forces: + pos.requires_grad_(True) + + ( + main_graph, + a2a_graph, + a2ee2a_graph, + qint_graph, + id_swap, + trip_idx_e2e, + trip_idx_a2e, + trip_idx_e2a, + quad_idx, + ) = self.get_graphs_and_indices(data) + _, idx_t = main_graph["edge_index"] + + ( + basis_rad_raw, + basis_atom_update, + basis_output, + bases_qint, + bases_e2e, + bases_a2e, + bases_e2a, + basis_a2a_rad, + ) = self.get_bases( + main_graph=main_graph, + a2a_graph=a2a_graph, + a2ee2a_graph=a2ee2a_graph, + qint_graph=qint_graph, + trip_idx_e2e=trip_idx_e2e, + trip_idx_a2e=trip_idx_a2e, + trip_idx_e2a=trip_idx_e2a, + quad_idx=quad_idx, + num_atoms=num_atoms, + ) + + # Embedding block + h = self.atom_emb(atomic_numbers) + # (nAtoms, emb_size_atom) + m = self.edge_emb(h, basis_rad_raw, main_graph["edge_index"]) + # (nEdges, emb_size_edge) + + x_E, x_F = self.out_blocks[0](h, m, basis_output, idx_t) + # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) + xs_E, xs_F = [x_E], [x_F] + + for i in range(self.num_blocks): + # Interaction block + h, m = self.int_blocks[i]( + h=h, + m=m, + bases_qint=bases_qint, + bases_e2e=bases_e2e, + bases_a2e=bases_a2e, + bases_e2a=bases_e2a, + basis_a2a_rad=basis_a2a_rad, + basis_atom_update=basis_atom_update, + edge_index_main=main_graph["edge_index"], + a2ee2a_graph=a2ee2a_graph, + a2a_graph=a2a_graph, + id_swap=id_swap, + trip_idx_e2e=trip_idx_e2e, + trip_idx_a2e=trip_idx_a2e, + trip_idx_e2a=trip_idx_e2a, + quad_idx=quad_idx, + ) # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) + + x_E, x_F = self.out_blocks[i + 1](h, m, basis_output, idx_t) + # (nAtoms, 
emb_size_atom), (nEdges, emb_size_edge) + xs_E.append(x_E) + xs_F.append(x_F) + + return { + "xs_E": xs_E, + "xs_F": xs_F, + "edge_vec": main_graph["vector"], + "edge_idx": idx_t, + "num_neighbors": main_graph["num_neighbors"], + } + + +@registry.register_model("gemnet_oc_energy_and_grad_force_head") +class GemNetOCEnergyAndGradForceHead(nn.Module, HeadInterface): + def __init__( + self, + backbone: BackboneInterface, + num_global_out_layers: int, + output_init: str = "HeOrthogonal", + ): + super().__init__() + self.extensive = backbone.extensive + + self.regress_forces = backbone.regress_forces + self.direct_forces = backbone.direct_forces + self.force_scaler = backbone.force_scaler + + out_mlp_E = [ + Dense( + backbone.atom_emb.emb_size * (len(backbone.int_blocks) + 1), + backbone.atom_emb.emb_size, + activation=backbone.activation, + ) + ] + [ + ResidualLayer( + backbone.atom_emb.emb_size, + activation=backbone.activation, + ) + for _ in range(num_global_out_layers) + ] + self.out_mlp_E = torch.nn.Sequential(*out_mlp_E) + + self.out_energy = Dense( + backbone.atom_emb.emb_size, + 1, + bias=False, + activation=None, + ) + + out_initializer = get_initializer(output_init) + self.out_energy.reset_parameters(out_initializer) + + @conditional_grad(torch.enable_grad()) + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + # Global output block for final predictions + x_E = self.out_mlp_E(torch.cat(emb["xs_E"], dim=-1)) + with torch.cuda.amp.autocast(False): + E_t = self.out_energy(x_E.float()) + + nMolecules = torch.max(data.batch) + 1 + if self.extensive: + E_t = scatter_det( + E_t, data.batch, dim=0, dim_size=nMolecules, reduce="add" + ) # (nMolecules, 1) + else: + E_t = scatter_det( + E_t, data.batch, dim=0, dim_size=nMolecules, reduce="mean" + ) # (nMolecules, 1) + + outputs = {"energy": E_t.squeeze(1)} # (num_molecules) + + if self.regress_forces and not self.direct_forces: + F_t = self.force_scaler.calc_forces_and_update(outputs["energy"], data.pos) + outputs["forces"] = F_t.squeeze(1) + return outputs + + +@registry.register_model("gemnet_oc_force_head") +class GemNetOCForceHead(nn.Module, HeadInterface): + def __init__( + self, backbone, num_global_out_layers: int, output_init: str = "HeOrthogonal" + ): + super().__init__() + + self.direct_forces = backbone.direct_forces + self.forces_coupled = backbone.forces_coupled + + emb_size_edge = backbone.edge_emb.dense.linear.out_features + if self.direct_forces: + out_mlp_F = [ + Dense( + emb_size_edge * (len(backbone.int_blocks) + 1), + emb_size_edge, + activation=backbone.activation, + ) + ] + [ + ResidualLayer( + emb_size_edge, + activation=backbone.activation, + ) + for _ in range(num_global_out_layers) + ] + self.out_mlp_F = torch.nn.Sequential(*out_mlp_F) + self.out_forces = Dense( + emb_size_edge, + 1, + bias=False, + activation=None, + ) + out_initializer = get_initializer(output_init) + self.out_forces.reset_parameters(out_initializer) + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + if self.direct_forces: + x_F = self.out_mlp_F(torch.cat(emb["xs_F"], dim=-1)) + with torch.cuda.amp.autocast(False): + F_st = self.out_forces(x_F.float()) + + if self.forces_coupled: # enforce F_st = F_ts + nEdges = emb["edge_idx"].shape[0] + id_undir = repeat_blocks( + emb["num_neighbors"] // 2, + repeats=2, + continuous_indexing=True, + ) + F_st = scatter_det( + F_st, + id_undir, + dim=0, + dim_size=int(nEdges / 2), + reduce="mean", + ) # (nEdges/2, 1) + 
F_st = F_st[id_undir] # (nEdges, 1) + + # map forces in edge directions + F_st_vec = F_st[:, :, None] * emb["edge_vec"][:, None, :] + # (nEdges, 1, 3) + F_t = scatter_det( + F_st_vec, + emb["edge_idx"], + dim=0, + dim_size=data.atomic_numbers.long().shape[0], + reduce="add", + ) # (nAtoms, 1, 3) + return {"forces": F_t.squeeze(1)} # (num_atoms, 3) + return {} diff --git a/src/fairchem/core/models/painn/painn.py b/src/fairchem/core/models/painn/painn.py index 8843f02b2..ec9e9f465 100644 --- a/src/fairchem/core/models/painn/painn.py +++ b/src/fairchem/core/models/painn/painn.py @@ -32,15 +32,19 @@ from __future__ import annotations import math +import typing import torch from torch import nn + +if typing.TYPE_CHECKING: + from torch_geometric.data.batch import Batch from torch_geometric.nn import MessagePassing from torch_scatter import scatter, segment_coo from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface from fairchem.core.models.gemnet.layers.base_layers import ScaledSiLU from fairchem.core.models.gemnet.layers.embedding_block import AtomEmbedding from fairchem.core.models.gemnet.layers.radial_basis import RadialBasis @@ -51,7 +55,7 @@ @registry.register_model("painn") -class PaiNN(BaseModel): +class PaiNN(nn.Module, GraphModelMixin): r"""PaiNN model based on the description in Schütt et al. (2021): Equivariant message passing for the prediction of tensorial properties and molecular spectra, https://arxiv.org/abs/2102.03150. @@ -59,9 +63,6 @@ class PaiNN(BaseModel): def __init__( self, - num_atoms: int, - bond_feat_dim: int, - num_targets: int, hidden_channels: int = 512, num_layers: int = 6, num_rbf: int = 128, @@ -310,23 +311,16 @@ def symmetrize_edges( ) def generate_graph_values(self, data): - ( - edge_index, - edge_dist, - distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph(data) + graph = self.generate_graph(data) # Unit vectors pointing from edge_index[1] to edge_index[0], # i.e., edge_index[0] - edge_index[1] divided by the norm. 
# make sure that the distances are not close to zero before dividing - mask_zero = torch.isclose(edge_dist, torch.tensor(0.0), atol=1e-6) - edge_dist[mask_zero] = 1.0e-6 - edge_vector = distance_vec / edge_dist[:, None] + mask_zero = torch.isclose(graph.edge_distance, torch.tensor(0.0), atol=1e-6) + graph.edge_distance[mask_zero] = 1.0e-6 + edge_vector = graph.edge_distance_vec / graph.edge_distance[:, None] - empty_image = neighbors == 0 + empty_image = graph.neighbors == 0 if torch.any(empty_image): raise ValueError( f"An image has no neighbors: id={data.id[empty_image]}, " @@ -342,11 +336,11 @@ def generate_graph_values(self, data): [edge_vector], id_swap, ) = self.symmetrize_edges( - edge_index, - cell_offsets, - neighbors, + graph.edge_index, + graph.cell_offsets, + graph.neighbors, data.batch, - [edge_dist], + [graph.edge_distance], [edge_vector], ) @@ -436,6 +430,50 @@ def __repr__(self) -> str: ) +@registry.register_model("painn_backbone") +class PaiNNBackbone(PaiNN, BackboneInterface): + @conditional_grad(torch.enable_grad()) + def forward(self, data) -> dict[str, torch.Tensor]: + pos = data.pos + z = data.atomic_numbers.long() + + if self.regress_forces and not self.direct_forces: + pos = pos.requires_grad_(True) + + ( + edge_index, + neighbors, + edge_dist, + edge_vector, + id_swap, + ) = self.generate_graph_values(data) + + assert z.dim() == 1 + assert z.dtype == torch.long + + edge_rbf = self.radial_basis(edge_dist) # rbf * envelope + + x = self.atom_emb(z) + vec = torch.zeros(x.size(0), 3, x.size(1), device=x.device) + + #### Interaction blocks ############################################### + + for i in range(self.num_layers): + dx, dvec = self.message_layers[i](x, vec, edge_index, edge_rbf, edge_vector) + + x = x + dx + vec = vec + dvec + x = x * self.inv_sqrt_2 + + dx, dvec = self.update_layers[i](x, vec) + + x = x + dx + vec = vec + dvec + x = getattr(self, "upd_out_scalar_scale_%d" % i)(x) + + return {"node_embedding": x, "node_vec": vec} + + class PaiNNMessage(MessagePassing): def __init__( self, @@ -625,3 +663,53 @@ def forward(self, x, v): x = self.act(x) return x, v + + +@registry.register_model("painn_energy_head") +class PaiNNEnergyHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + + self.out_energy = nn.Sequential( + nn.Linear(backbone.hidden_channels, backbone.hidden_channels // 2), + ScaledSiLU(), + nn.Linear(backbone.hidden_channels // 2, 1), + ) + + nn.init.xavier_uniform_(self.out_energy[0].weight) + self.out_energy[0].bias.data.fill_(0) + nn.init.xavier_uniform_(self.out_energy[2].weight) + self.out_energy[2].bias.data.fill_(0) + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + per_atom_energy = self.out_energy(emb["node_embedding"]).squeeze(1) + return {"energy": scatter(per_atom_energy, data.batch, dim=0)} + + +@registry.register_model("painn_force_head") +class PaiNNForceHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + self.direct_forces = backbone.direct_forces + + if self.direct_forces: + self.out_forces = PaiNNOutput(backbone.hidden_channels) + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + if self.direct_forces: + forces = self.out_forces(emb["node_embedding"], emb["node_vec"]) + else: + forces = ( + -1 + * torch.autograd.grad( + emb["node_embedding"], + data.pos, + grad_outputs=torch.ones_like(emb["node_embedding"]), + create_graph=True, + )[0] + ) + return {"forces": forces} 
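The PaiNN heads above illustrate the head contract used throughout this patch: a head is an nn.Module that also implements HeadInterface, receives the backbone at construction time, and maps (data, emb) to a dictionary of named outputs. For reference, a minimal sketch of a custom head following that contract is shown below; the registry name "my_scalar_head" and the output key "my_scalar" are illustrative assumptions, while the "node_embedding" key and backbone.hidden_channels match what painn_backbone exposes in this diff.

    from __future__ import annotations

    import torch
    from torch import nn
    from torch_scatter import scatter

    from fairchem.core.common.registry import registry
    from fairchem.core.models.base import HeadInterface


    @registry.register_model("my_scalar_head")  # illustrative name, not part of this patch
    class MyScalarHead(nn.Module, HeadInterface):
        """Pool per-atom embeddings from a backbone into one scalar per structure."""

        def __init__(self, backbone):
            super().__init__()
            # hidden_channels is the per-atom embedding width produced by painn_backbone
            self.linear = nn.Linear(backbone.hidden_channels, 1)

        def forward(self, data, emb: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
            per_atom = self.linear(emb["node_embedding"]).squeeze(-1)  # (num_atoms,)
            # sum-pool atoms belonging to the same structure in the batch
            return {"my_scalar": scatter(per_atom, data.batch, dim=0, reduce="add")}

Registered this way, such a head could be referenced from a hydra config in the same manner as the painn_energy_head and painn_force_head entries in the test configs further down.
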
diff --git a/src/fairchem/core/models/schnet.py b/src/fairchem/core/models/schnet.py index 2f89c17e1..5ca70a354 100644 --- a/src/fairchem/core/models/schnet.py +++ b/src/fairchem/core/models/schnet.py @@ -13,11 +13,11 @@ from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import GraphModelMixin @registry.register_model("schnet") -class SchNetWrap(SchNet, BaseModel): +class SchNetWrap(SchNet, GraphModelMixin): r"""Wrapper around the continuous-filter convolutional neural network SchNet from the `"SchNet: A Continuous-filter Convolutional Neural Network for Modeling Quantum Interactions" `_. Each layer uses interaction @@ -28,9 +28,6 @@ class SchNetWrap(SchNet, BaseModel): h_{\mathbf{\Theta}} ( \exp(-\gamma(\mathbf{e}_{j,i} - \mathbf{\mu}))), Args: - num_atoms (int): Unused argument - bond_feat_dim (int): Unused argument - num_targets (int): Number of targets to predict. use_pbc (bool, optional): If set to :obj:`True`, account for periodic boundary conditions. (default: :obj:`True`) regress_forces (bool, optional): If set to :obj:`True`, predict forces by differentiating @@ -54,9 +51,6 @@ class SchNetWrap(SchNet, BaseModel): def __init__( self, - num_atoms: int, # not used - bond_feat_dim: int, # not used - num_targets: int, use_pbc: bool = True, regress_forces: bool = True, otf_graph: bool = False, @@ -67,7 +61,7 @@ def __init__( cutoff: float = 10.0, readout: str = "add", ) -> None: - self.num_targets = num_targets + self.num_targets = 1 self.regress_forces = regress_forces self.use_pbc = use_pbc self.cutoff = cutoff @@ -88,25 +82,17 @@ def _forward(self, data): z = data.atomic_numbers.long() pos = data.pos batch = data.batch - - ( - edge_index, - edge_weight, - distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph(data) + graph = self.generate_graph(data) if self.use_pbc: assert z.dim() == 1 assert z.dtype == torch.long - edge_attr = self.distance_expansion(edge_weight) + edge_attr = self.distance_expansion(graph.edge_distance) h = self.embedding(z) for interaction in self.interactions: - h = h + interaction(h, edge_index, edge_weight, edge_attr) + h = h + interaction(h, graph.edge_index, graph.edge_distance, edge_attr) h = self.lin1(h) h = self.act(h) diff --git a/src/fairchem/core/models/scn/scn.py b/src/fairchem/core/models/scn/scn.py index bf8454f21..84806e19e 100644 --- a/src/fairchem/core/models/scn/scn.py +++ b/src/fairchem/core/models/scn/scn.py @@ -18,7 +18,7 @@ from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import GraphModelMixin from fairchem.core.models.scn.sampling import CalcSpherePoints from fairchem.core.models.scn.smearing import ( GaussianSmearing, @@ -33,7 +33,7 @@ @registry.register_model("scn") -class SphericalChannelNetwork(BaseModel): +class SphericalChannelNetwork(nn.Module, GraphModelMixin): """Spherical Channel Network Paper: Spherical Channels for Modeling Atomic Interactions @@ -75,9 +75,6 @@ class SphericalChannelNetwork(BaseModel): def __init__( self, - num_atoms: int, # not used - bond_feat_dim: int, # not used - num_targets: int, # not used use_pbc: bool = True, regress_forces: bool = True, otf_graph: bool = False, @@ -262,15 +259,7 @@ def _forward_helper(self, data): atomic_numbers = data.atomic_numbers.long() num_atoms = len(atomic_numbers) 
pos = data.pos - - ( - edge_index, - edge_distance, - edge_distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph(data) + graph = self.generate_graph(data) ############################################################### # Initialize data structures @@ -278,12 +267,12 @@ def _forward_helper(self, data): # Calculate which message block each edge should use. Based on edge distance rank. edge_rank = self._rank_edge_distances( - edge_distance, edge_index, self.max_num_neighbors + graph.edge_distance, graph.edge_index, self.max_num_neighbors ) # Reorder edges so that they are grouped by distance rank (lowest to highest) last_cutoff = -0.1 - message_block_idx = torch.zeros(len(edge_distance), device=pos.device) + message_block_idx = torch.zeros(len(graph.edge_distance), device=pos.device) edge_distance_reorder = torch.tensor([], device=self.device) edge_index_reorder = torch.tensor([], device=self.device) edge_distance_vec_reorder = torch.tensor([], device=self.device) @@ -297,21 +286,21 @@ def _forward_helper(self, data): edge_distance_reorder = torch.cat( [ edge_distance_reorder, - torch.masked_select(edge_distance, mask), + torch.masked_select(graph.edge_distance, mask), ], dim=0, ) edge_index_reorder = torch.cat( [ edge_index_reorder, - torch.masked_select(edge_index, mask.view(1, -1).repeat(2, 1)).view( - 2, -1 - ), + torch.masked_select( + graph.edge_index, mask.view(1, -1).repeat(2, 1) + ).view(2, -1), ], dim=1, ) edge_distance_vec_mask = torch.masked_select( - edge_distance_vec, mask.view(-1, 1).repeat(1, 3) + graph.edge_distance_vec, mask.view(-1, 1).repeat(1, 3) ).view(-1, 3) edge_distance_vec_reorder = torch.cat( [edge_distance_vec_reorder, edge_distance_vec_mask], dim=0 diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index 79a9c1ebe..8d5f61946 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -7,6 +7,7 @@ from __future__ import annotations +import copy import datetime import errno import logging @@ -481,19 +482,7 @@ def load_model(self) -> None: if distutils.is_master(): logging.info(f"Loading model: {self.config['model']}") - # TODO: depreicated, remove. 
- bond_feat_dim = None - bond_feat_dim = self.config["model_attributes"].get("num_gaussians", 50) - - loader = self.train_loader or self.val_loader or self.test_loader self.model = registry.get_model_class(self.config["model"])( - loader.dataset[0].x.shape[-1] - if loader - and hasattr(loader.dataset[0], "x") - and loader.dataset[0].x is not None - else None, - bond_feat_dim, - 1, **self.config["model_attributes"], ).to(self.device) diff --git a/tests/core/e2e/test_s2ef.py b/tests/core/e2e/test_s2ef.py index 86bf979ef..e08e8e981 100644 --- a/tests/core/e2e/test_s2ef.py +++ b/tests/core/e2e/test_s2ef.py @@ -22,9 +22,32 @@ @pytest.fixture() def configs(): return { + "scn": Path("tests/core/models/test_configs/test_scn.yml"), "escn": Path("tests/core/models/test_configs/test_escn.yml"), - "gemnet": Path("tests/core/models/test_configs/test_gemnet.yml"), + "escn_hydra": Path("tests/core/models/test_configs/test_escn_hydra.yml"), + "schnet": Path("tests/core/models/test_configs/test_schnet.yml"), + "gemnet_dt": Path("tests/core/models/test_configs/test_gemnet_dt.yml"), + "gemnet_dt_hydra": Path( + "tests/core/models/test_configs/test_gemnet_dt_hydra.yml" + ), + "gemnet_dt_hydra_grad": Path( + "tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml" + ), + "gemnet_oc": Path("tests/core/models/test_configs/test_gemnet_oc.yml"), + "gemnet_oc_hydra": Path( + "tests/core/models/test_configs/test_gemnet_oc_hydra.yml" + ), + "gemnet_oc_hydra_grad": Path( + "tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml" + ), + "dimenet++": Path("tests/core/models/test_configs/test_dpp.yml"), + "dimenet++_hydra": Path("tests/core/models/test_configs/test_dpp_hydra.yml"), + "painn": Path("tests/core/models/test_configs/test_painn.yml"), + "painn_hydra": Path("tests/core/models/test_configs/test_painn_hydra.yml"), "equiformer_v2": Path("tests/core/models/test_configs/test_equiformerv2.yml"), + "equiformer_v2_hydra": Path( + "tests/core/models/test_configs/test_equiformerv2_hydra.yml" + ), } @@ -183,7 +206,7 @@ def smoke_test_train(self, input_yaml, tutorial_val_src, otf_norms=False): rundir=str(train_rundir), input_yaml=input_yaml, update_dict_with={ - "optim": {"max_epochs": 2, "eval_every": 8}, + "optim": {"max_epochs": 2, "eval_every": 8, "batch_size": 5}, "dataset": oc20_lmdb_train_and_val_from_paths( train_src=str(tutorial_val_src), val_src=str(tutorial_val_src), @@ -205,7 +228,7 @@ def smoke_test_train(self, input_yaml, tutorial_val_src, otf_norms=False): rundir=str(predictions_rundir), input_yaml=input_yaml, update_dict_with={ - "optim": {"max_epochs": 2, "eval_every": 8}, + "optim": {"max_epochs": 2, "eval_every": 8, "batch_size": 5}, "dataset": oc20_lmdb_train_and_val_from_paths( train_src=str(tutorial_val_src), val_src=str(tutorial_val_src), @@ -242,11 +265,22 @@ def smoke_test_train(self, input_yaml, tutorial_val_src, otf_norms=False): @pytest.mark.parametrize( ("model_name", "otf_norms"), [ - ("gemnet", False), - ("escn", False), - ("escn", True), - ("equiformer_v2", False), - ("equiformer_v2", True), + pytest.param("schnet", id="schnet"), + pytest.param("scn", id="scn"), + pytest.param("gemnet_dt", id="gemnet_dt"), + pytest.param("gemnet_dt_hydra", id="gemnet_dt_hydra"), + pytest.param("gemnet_dt_hydra_grad", id="gemnet_dt_hydra_grad"), + pytest.param("gemnet_oc", id="gemnet_oc"), + pytest.param("gemnet_oc_hydra", id="gemnet_oc_hydra"), + pytest.param("gemnet_oc_hydra_grad", id="gemnet_oc_hydra_grad"), + pytest.param("dimenet++", id="dimenet++"), + pytest.param("dimenet++_hydra", 
id="dimenet++_hydra"), + pytest.param("painn", id="painn"), + pytest.param("painn_hydra", id="painn_hydra"), + pytest.param("escn", id="escn"), + pytest.param("escn_hydra", id="escn_hydra"), + pytest.param("equiformer_v2", id="equiformer_v2"), + pytest.param("equiformer_v2_hydra", id="equiformer_v2_hydra"), ], ) def test_train_and_predict( @@ -339,9 +373,9 @@ class TestSmallDatasetOptim: @pytest.mark.parametrize( ("model_name", "expected_energy_mae", "expected_force_mae"), [ - pytest.param("gemnet", 0.4, 0.06, id="gemnet"), - pytest.param("escn", 0.4, 0.06, id="escn"), - pytest.param("equiformer_v2", 0.4, 0.06, id="equiformer_v2"), + pytest.param("gemnet_oc", 0.41, 0.06, id="gemnet_oc"), + pytest.param("escn", 0.41, 0.06, id="escn"), + pytest.param("equiformer_v2", 0.41, 0.06, id="equiformer_v2"), ], ) def test_train_optimization( diff --git a/tests/core/models/test_configs/test_dpp.yml b/tests/core/models/test_configs/test_dpp.yml new file mode 100755 index 000000000..a79294bd1 --- /dev/null +++ b/tests/core/models/test_configs/test_dpp.yml @@ -0,0 +1,50 @@ +trainer: forces + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +logger: + name: tensorboard + +model: + name: dimenetplusplus #_bbwheads + hidden_channels: 4 + out_emb_channels: 4 + num_blocks: 3 + cutoff: 6.0 + num_radial: 6 + num_spherical: 7 + num_before_skip: 1 + num_after_skip: 2 + num_output_layers: 3 + regress_forces: True + use_pbc: True + +# *** Important note *** +# The total number of gpus used for this run was 256. +# If the global batch size (num_gpus * batch_size) is modified +# the lr_milestones and warmup_steps need to be adjusted accordingly. + +optim: + batch_size: 5 + eval_batch_size: 2 + eval_every: 1000 + num_workers: 8 + lr_initial: 0.0001 + lr_gamma: 0.1 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 130794 + - 196192 + - 261589 + warmup_steps: 130794 + warmup_factor: 0.2 + max_epochs: 7 diff --git a/tests/core/models/test_configs/test_dpp_hydra.yml b/tests/core/models/test_configs/test_dpp_hydra.yml new file mode 100755 index 000000000..1120cc905 --- /dev/null +++ b/tests/core/models/test_configs/test_dpp_hydra.yml @@ -0,0 +1,55 @@ +trainer: forces + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +logger: + name: tensorboard + +model: + name: hydra + backbone: + model: dimenetplusplus_backbone + hidden_channels: 4 + out_emb_channels: 4 + num_blocks: 1 + cutoff: 6.0 + num_radial: 6 + num_spherical: 7 + num_before_skip: 1 + num_after_skip: 2 + num_output_layers: 3 + regress_forces: True + use_pbc: True + heads: + energy: + module: dimenetplusplus_energy_and_force_head + +# *** Important note *** +# The total number of gpus used for this run was 256. +# If the global batch size (num_gpus * batch_size) is modified +# the lr_milestones and warmup_steps need to be adjusted accordingly. 
+ +optim: + batch_size: 5 + eval_batch_size: 2 + eval_every: 1000 + num_workers: 8 + lr_initial: 0.0001 + lr_gamma: 0.1 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 130794 + - 196192 + - 261589 + warmup_steps: 130794 + warmup_factor: 0.2 + max_epochs: 7 diff --git a/tests/core/models/test_configs/test_equiformerv2_hydra.yml b/tests/core/models/test_configs/test_equiformerv2_hydra.yml new file mode 100644 index 000000000..4c00fe6a2 --- /dev/null +++ b/tests/core/models/test_configs/test_equiformerv2_hydra.yml @@ -0,0 +1,98 @@ + + +trainer: forces + +model: + name: hydra + backbone: + model: equiformer_v2_backbone + use_pbc: True + regress_forces: True + otf_graph: True + + enforce_max_neighbors_strictly: False + + max_neighbors: 1 + max_radius: 12.0 + max_num_elements: 90 + + num_layers: 1 + sphere_channels: 4 + attn_hidden_channels: 4 # [64, 96] This determines the hidden size of message passing. Do not necessarily use 96. + num_heads: 1 + attn_alpha_channels: 4 # Not used when `use_s2_act_attn` is True. + attn_value_channels: 4 + ffn_hidden_channels: 8 + norm_type: 'layer_norm_sh' # ['rms_norm_sh', 'layer_norm', 'layer_norm_sh'] + + lmax_list: [1] + mmax_list: [1] + grid_resolution: 18 # [18, 16, 14, None] For `None`, simply comment this line. + + num_sphere_samples: 128 + + edge_channels: 32 + use_atom_edge_embedding: True + distance_function: 'gaussian' + num_distance_basis: 16 # not used + + attn_activation: 'silu' + use_s2_act_attn: False # [False, True] Switch between attention after S2 activation or the original EquiformerV1 attention. + ffn_activation: 'silu' # ['silu', 'swiglu'] + use_gate_act: False # [True, False] Switch between gate activation and S2 activation + use_grid_mlp: False # [False, True] If `True`, use projecting to grids and performing MLPs for FFNs. 
+ + alpha_drop: 0.0 # [0.0, 0.1] + drop_path_rate: 0.0 # [0.0, 0.05] + proj_drop: 0.0 + + weight_init: 'normal' # ['uniform', 'normal'] + heads: + energy: + module: equiformer_v2_energy_head + forces: + module: equiformer_v2_force_head + +dataset: + train: + src: tutorial_dset/s2ef/train_100/ + normalize_labels: True + target_mean: -0.7554450631141663 + target_std: 2.887317180633545 + grad_target_mean: 0.0 + grad_target_std: 2.887317180633545 + val: + format: lmdb + src: tutorial_dset/s2ef/val_20/ + +logger: + name: tensorboard + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + + +optim: + batch_size: 5 + eval_batch_size: 2 + num_workers: 0 + lr_initial: 0.0025 + optimizer: AdamW + optimizer_params: {"amsgrad": True,weight_decay: 0.0} + eval_every: 190 + max_epochs: 50 + force_coefficient: 20 + scheduler: "Null" + energy_coefficient: 1 + clip_grad_norm: 20 + loss_energy: mae + loss_force: l2mae diff --git a/tests/core/models/test_configs/test_escn_hydra.yml b/tests/core/models/test_configs/test_escn_hydra.yml new file mode 100644 index 000000000..ba5db1f53 --- /dev/null +++ b/tests/core/models/test_configs/test_escn_hydra.yml @@ -0,0 +1,67 @@ +trainer: forces + +dataset: + train: + src: tutorial_dset/s2ef/train_100/ + normalize_labels: True + target_mean: -0.7554450631141663 + target_std: 2.887317180633545 + grad_target_mean: 0.0 + grad_target_std: 2.887317180633545 + val: + format: lmdb + src: tutorial_dset/s2ef/val_20/ + +logger: + name: tensorboard + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +model: + name: hydra + backbone: + model: escn_backbone + num_layers: 2 + max_neighbors: 10 + cutoff: 12.0 + sphere_channels: 8 + hidden_channels: 8 + lmax_list: [2] + mmax_list: [2] + num_sphere_samples: 64 + distance_function: "gaussian" + regress_forces: True + use_pbc: True + basis_width_scalar: 2.0 + otf_graph: True + heads: + energy: + module: escn_energy_head + forces: + module: escn_force_head + +optim: + batch_size: 5 + eval_batch_size: 2 + num_workers: 0 + lr_initial: 0.0025 + optimizer: AdamW + optimizer_params: {"amsgrad": True,weight_decay: 0.0} + eval_every: 190 + max_epochs: 50 + force_coefficient: 20 + scheduler: "Null" + energy_coefficient: 1 + clip_grad_norm: 20 + loss_energy: mae + loss_force: l2mae diff --git a/tests/core/models/test_configs/test_gemnet_dt.yml b/tests/core/models/test_configs/test_gemnet_dt.yml new file mode 100644 index 000000000..b04b6dfda --- /dev/null +++ b/tests/core/models/test_configs/test_gemnet_dt.yml @@ -0,0 +1,79 @@ +trainer: forces + +dataset: + train: + src: tutorial_dset/s2ef/train_100/ + normalize_labels: True + target_mean: -0.7554450631141663 + target_std: 2.887317180633545 + grad_target_mean: 0.0 + grad_target_std: 2.887317180633545 + val: + format: lmdb + src: tutorial_dset/s2ef/val_20/ + +logger: + name: tensorboard + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +model: + name: gemnet_t + num_spherical: 3 + num_radial: 8 + num_blocks: 2 + emb_size_atom: 8 + emb_size_edge: 16 + emb_size_trip: 4 + 
emb_size_rbf: 4 + emb_size_cbf: 4 + emb_size_bil_trip: 4 + num_before_skip: 1 + num_after_skip: 2 + num_concat: 1 + num_atom: 3 + cutoff: 6.0 + max_neighbors: 50 + rbf: + name: gaussian + envelope: + name: polynomial + exponent: 5 + cbf: + name: spherical_harmonics + extensive: True + otf_graph: False + output_init: HeOrthogonal + activation: silu + scale_file: configs/s2ef/all/gemnet/scaling_factors/gemnet-dT.json + + regress_forces: True + direct_forces: True + +optim: + batch_size: 8 + eval_batch_size: 8 + eval_every: 5000 + num_workers: 2 + lr_initial: 5.e-4 + optimizer: AdamW + optimizer_params: {"amsgrad": True} + scheduler: ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 3 + max_epochs: 80 + force_coefficient: 100 + energy_coefficient: 1 + ema_decay: 0.999 + clip_grad_norm: 10 diff --git a/tests/core/models/test_configs/test_gemnet_dt_hydra.yml b/tests/core/models/test_configs/test_gemnet_dt_hydra.yml new file mode 100644 index 000000000..a61274147 --- /dev/null +++ b/tests/core/models/test_configs/test_gemnet_dt_hydra.yml @@ -0,0 +1,86 @@ +trainer: forces + +dataset: + train: + src: tutorial_dset/s2ef/train_100/ + normalize_labels: True + target_mean: -0.7554450631141663 + target_std: 2.887317180633545 + grad_target_mean: 0.0 + grad_target_std: 2.887317180633545 + val: + format: lmdb + src: tutorial_dset/s2ef/val_20/ + +logger: + name: tensorboard + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +model: + name: hydra + backbone: + model: gemnet_t_backbone + num_spherical: 3 + num_radial: 8 + num_blocks: 2 + emb_size_atom: 8 + emb_size_edge: 16 + emb_size_trip: 4 + emb_size_rbf: 4 + emb_size_cbf: 4 + emb_size_bil_trip: 4 + num_before_skip: 1 + num_after_skip: 2 + num_concat: 1 + num_atom: 3 + cutoff: 6.0 + max_neighbors: 50 + rbf: + name: gaussian + envelope: + name: polynomial + exponent: 5 + cbf: + name: spherical_harmonics + extensive: True + otf_graph: False + output_init: HeOrthogonal + activation: silu + scale_file: configs/s2ef/all/gemnet/scaling_factors/gemnet-dT.json + + regress_forces: True + direct_forces: True + heads: + energy: + module: gemnet_t_energy_and_grad_force_head + forces: + module: gemnet_t_force_head + +optim: + batch_size: 8 + eval_batch_size: 8 + eval_every: 5000 + num_workers: 2 + lr_initial: 5.e-4 + optimizer: AdamW + optimizer_params: {"amsgrad": True} + scheduler: ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 3 + max_epochs: 80 + force_coefficient: 100 + energy_coefficient: 1 + ema_decay: 0.999 + clip_grad_norm: 10 diff --git a/tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml b/tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml new file mode 100644 index 000000000..83d46bdd4 --- /dev/null +++ b/tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml @@ -0,0 +1,84 @@ +trainer: forces + +dataset: + train: + src: tutorial_dset/s2ef/train_100/ + normalize_labels: True + target_mean: -0.7554450631141663 + target_std: 2.887317180633545 + grad_target_mean: 0.0 + grad_target_std: 2.887317180633545 + val: + format: lmdb + src: tutorial_dset/s2ef/val_20/ + +logger: + name: tensorboard + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +model: + name: hydra + 
backbone: + model: gemnet_t_backbone + num_spherical: 3 + num_radial: 8 + num_blocks: 2 + emb_size_atom: 8 + emb_size_edge: 16 + emb_size_trip: 4 + emb_size_rbf: 4 + emb_size_cbf: 4 + emb_size_bil_trip: 4 + num_before_skip: 1 + num_after_skip: 2 + num_concat: 1 + num_atom: 3 + cutoff: 6.0 + max_neighbors: 50 + rbf: + name: gaussian + envelope: + name: polynomial + exponent: 5 + cbf: + name: spherical_harmonics + extensive: True + otf_graph: False + output_init: HeOrthogonal + activation: silu + scale_file: configs/s2ef/all/gemnet/scaling_factors/gemnet-dT.json + + regress_forces: True + direct_forces: False + heads: + energy_and_forces: + module: gemnet_t_energy_and_grad_force_head + +optim: + batch_size: 8 + eval_batch_size: 8 + eval_every: 5000 + num_workers: 2 + lr_initial: 5.e-4 + optimizer: AdamW + optimizer_params: {"amsgrad": True} + scheduler: ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 3 + max_epochs: 80 + force_coefficient: 100 + energy_coefficient: 1 + ema_decay: 0.999 + clip_grad_norm: 10 diff --git a/tests/core/models/test_configs/test_gemnet.yml b/tests/core/models/test_configs/test_gemnet_oc.yml similarity index 100% rename from tests/core/models/test_configs/test_gemnet.yml rename to tests/core/models/test_configs/test_gemnet_oc.yml diff --git a/tests/core/models/test_configs/test_gemnet_oc_hydra.yml b/tests/core/models/test_configs/test_gemnet_oc_hydra.yml new file mode 100644 index 000000000..97343e90e --- /dev/null +++ b/tests/core/models/test_configs/test_gemnet_oc_hydra.yml @@ -0,0 +1,112 @@ + + + +trainer: forces + +dataset: + train: + src: tutorial_dset/s2ef/train_100/ + normalize_labels: True + target_mean: -0.7554450631141663 + target_std: 2.887317180633545 + grad_target_mean: 0.0 + grad_target_std: 2.887317180633545 + val: + format: lmdb + src: tutorial_dset/s2ef/val_20/ + +logger: + name: tensorboard + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +model: + name: hydra + backbone: + model: gemnet_oc_backbone + num_spherical: 3 + num_radial: 8 + num_blocks: 2 + emb_size_atom: 8 + emb_size_edge: 16 + emb_size_trip_in: 4 + emb_size_trip_out: 4 + emb_size_quad_in: 2 + emb_size_quad_out: 2 + emb_size_aint_in: 4 + emb_size_aint_out: 4 + emb_size_rbf: 2 + emb_size_cbf: 2 + emb_size_sbf: 4 + num_before_skip: 1 + num_after_skip: 1 + num_concat: 1 + num_atom: 3 + num_output_afteratom: 3 + cutoff: 12.0 + cutoff_qint: 12.0 + cutoff_aeaint: 12.0 + cutoff_aint: 12.0 + max_neighbors: 30 + max_neighbors_qint: 8 + max_neighbors_aeaint: 20 + max_neighbors_aint: 1000 + rbf: + name: gaussian + envelope: + name: polynomial + exponent: 5 + cbf: + name: spherical_harmonics + sbf: + name: legendre_outer + extensive: True + output_init: HeOrthogonal + activation: silu + scale_file: configs/s2ef/all/gemnet/scaling_factors/gemnet-oc.pt + + regress_forces: True + direct_forces: True + forces_coupled: False + + quad_interaction: True + atom_edge_interaction: True + edge_atom_interaction: True + atom_interaction: True + + num_atom_emb_layers: 2 + num_global_out_layers: 2 + qint_tags: [1, 2] + heads: + energy: + module: gemnet_oc_energy_and_grad_force_head + num_global_out_layers: 2 + forces: + module: gemnet_oc_force_head + num_global_out_layers: 2 + +optim: + batch_size: 5 + eval_batch_size: 2 + num_workers: 0 + lr_initial: 0.0025 + optimizer: AdamW + optimizer_params: {"amsgrad": True,weight_decay: 
0.0} + eval_every: 190 + max_epochs: 50 + force_coefficient: 10 + scheduler: "Null" + energy_coefficient: 1 + clip_grad_norm: 20 + loss_energy: mae + loss_force: l2mae diff --git a/tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml b/tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml new file mode 100644 index 000000000..334c3cb4d --- /dev/null +++ b/tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml @@ -0,0 +1,109 @@ + + + +trainer: forces + +dataset: + train: + src: tutorial_dset/s2ef/train_100/ + normalize_labels: True + target_mean: -0.7554450631141663 + target_std: 2.887317180633545 + grad_target_mean: 0.0 + grad_target_std: 2.887317180633545 + val: + format: lmdb + src: tutorial_dset/s2ef/val_20/ + +logger: + name: tensorboard + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +model: + name: hydra + backbone: + model: gemnet_oc_backbone + num_spherical: 3 + num_radial: 8 + num_blocks: 2 + emb_size_atom: 8 + emb_size_edge: 16 + emb_size_trip_in: 4 + emb_size_trip_out: 4 + emb_size_quad_in: 2 + emb_size_quad_out: 2 + emb_size_aint_in: 4 + emb_size_aint_out: 4 + emb_size_rbf: 2 + emb_size_cbf: 2 + emb_size_sbf: 4 + num_before_skip: 1 + num_after_skip: 1 + num_concat: 1 + num_atom: 3 + num_output_afteratom: 3 + cutoff: 12.0 + cutoff_qint: 12.0 + cutoff_aeaint: 12.0 + cutoff_aint: 12.0 + max_neighbors: 30 + max_neighbors_qint: 8 + max_neighbors_aeaint: 20 + max_neighbors_aint: 1000 + rbf: + name: gaussian + envelope: + name: polynomial + exponent: 5 + cbf: + name: spherical_harmonics + sbf: + name: legendre_outer + extensive: True + output_init: HeOrthogonal + activation: silu + scale_file: configs/s2ef/all/gemnet/scaling_factors/gemnet-oc.pt + + regress_forces: True + direct_forces: False + forces_coupled: False + + quad_interaction: True + atom_edge_interaction: True + edge_atom_interaction: True + atom_interaction: True + + num_atom_emb_layers: 2 + num_global_out_layers: 2 + qint_tags: [1, 2] + heads: + energy: + module: gemnet_oc_energy_and_grad_force_head + num_global_out_layers: 2 + +optim: + batch_size: 5 + eval_batch_size: 2 + num_workers: 0 + lr_initial: 0.0025 + optimizer: AdamW + optimizer_params: {"amsgrad": True,weight_decay: 0.0} + eval_every: 190 + max_epochs: 50 + force_coefficient: 10 + scheduler: "Null" + energy_coefficient: 1 + clip_grad_norm: 20 + loss_energy: mae + loss_force: l2mae diff --git a/tests/core/models/test_configs/test_painn.yml b/tests/core/models/test_configs/test_painn.yml new file mode 100644 index 000000000..c1f24d0bb --- /dev/null +++ b/tests/core/models/test_configs/test_painn.yml @@ -0,0 +1,50 @@ +trainer: forces + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +logger: + name: tensorboard + +model: + name: painn #_bbwheads + hidden_channels: 32 + num_layers: 6 + num_rbf: 32 + cutoff: 12.0 + max_neighbors: 5 + scale_file: configs/s2ef/all/painn/painn_nb6_scaling_factors.pt + regress_forces: True + direct_forces: True + use_pbc: True + +optim: + batch_size: 32 + eval_batch_size: 32 + load_balancing: atoms + eval_every: 5000 + num_workers: 2 + optimizer: AdamW + optimizer_params: + amsgrad: True + weight_decay: 0. 
# 2e-6 (TF weight decay) / 1e-4 (lr) = 2e-2 + lr_initial: 1.e-4 + lr_gamma: 0.8 + scheduler: ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 3 + max_epochs: 80 + force_coefficient: 100 + energy_coefficient: 1 + ema_decay: 0.999 + clip_grad_norm: 10 diff --git a/tests/core/models/test_configs/test_painn_hydra.yml b/tests/core/models/test_configs/test_painn_hydra.yml new file mode 100644 index 000000000..0b39aa173 --- /dev/null +++ b/tests/core/models/test_configs/test_painn_hydra.yml @@ -0,0 +1,58 @@ +trainer: forces + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +logger: + name: tensorboard + +model: + name: hydra + backbone: + model: painn_backbone #_bbwheads + hidden_channels: 32 + num_layers: 6 + num_rbf: 32 + cutoff: 12.0 + max_neighbors: 5 + scale_file: configs/s2ef/all/painn/painn_nb6_scaling_factors.pt + regress_forces: True + direct_forces: True + use_pbc: True + heads: + energy: + module: painn_energy_head + forces: + module: painn_force_head + + +optim: + batch_size: 32 + eval_batch_size: 32 + load_balancing: atoms + eval_every: 5000 + num_workers: 2 + optimizer: AdamW + optimizer_params: + amsgrad: True + weight_decay: 0. # 2e-6 (TF weight decay) / 1e-4 (lr) = 2e-2 + lr_initial: 1.e-4 + lr_gamma: 0.8 + scheduler: ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 3 + max_epochs: 80 + force_coefficient: 100 + energy_coefficient: 1 + ema_decay: 0.999 + clip_grad_norm: 10 diff --git a/tests/core/models/test_configs/test_schnet.yml b/tests/core/models/test_configs/test_schnet.yml new file mode 100755 index 000000000..97faf3962 --- /dev/null +++ b/tests/core/models/test_configs/test_schnet.yml @@ -0,0 +1,45 @@ +trainer: forces + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +logger: + name: tensorboard + +model: + name: schnet + hidden_channels: 1024 + num_filters: 256 + num_interactions: 5 + num_gaussians: 200 + cutoff: 6.0 + use_pbc: True + +# *** Important note *** +# The total number of gpus used for this run was 64. +# If the global batch size (num_gpus * batch_size) is modified +# the lr_milestones and warmup_steps need to be adjusted accordingly. + +optim: + batch_size: 20 + eval_batch_size: 20 + eval_every: 10000 + num_workers: 16 + lr_initial: 0.0001 + lr_gamma: 0.1 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 313907 + - 523179 + - 732451 + warmup_steps: 209271 + warmup_factor: 0.2 + max_epochs: 15 diff --git a/tests/core/models/test_configs/test_scn.yml b/tests/core/models/test_configs/test_scn.yml new file mode 100755 index 000000000..c080c4855 --- /dev/null +++ b/tests/core/models/test_configs/test_scn.yml @@ -0,0 +1,59 @@ +# A total of 64 32GB GPUs were used for training. 
+trainer: forces + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +logger: + name: tensorboard + +model: + name: scn + num_interactions: 2 + hidden_channels: 16 + sphere_channels: 8 + sphere_channels_reduce: 8 + num_sphere_samples: 8 + num_basis_functions: 8 + distance_function: "gaussian" + show_timing_info: False + max_num_neighbors: 40 + cutoff: 8.0 + lmax: 4 + num_bands: 2 + use_grid: True + regress_forces: True + use_pbc: True + basis_width_scalar: 2.0 + otf_graph: True + +optim: + batch_size: 2 + eval_batch_size: 1 + num_workers: 2 + lr_initial: 0.0004 + optimizer: AdamW + optimizer_params: {"amsgrad": True} + eval_every: 5000 + lr_gamma: 0.3 + lr_milestones: # epochs at which lr_initial <- lr_initial * lr_gamma + - 260000 + - 340000 + - 420000 + - 500000 + - 800000 + - 1000000 + warmup_steps: 100 + warmup_factor: 0.2 + max_epochs: 12 + clip_grad_norm: 100 + ema_decay: 0.999 diff --git a/tests/core/models/test_dimenetpp.py b/tests/core/models/test_dimenetpp.py index 76a546037..d1daec728 100644 --- a/tests/core/models/test_dimenetpp.py +++ b/tests/core/models/test_dimenetpp.py @@ -47,9 +47,6 @@ def load_model(request) -> None: setup_imports() model = registry.get_model_class("dimenetplusplus")( - None, - 32, - 1, cutoff=6.0, regress_forces=True, use_pbc=False, diff --git a/tests/core/models/test_equiformer_v2.py b/tests/core/models/test_equiformer_v2.py index 34ed79ba2..88230874e 100644 --- a/tests/core/models/test_equiformer_v2.py +++ b/tests/core/models/test_equiformer_v2.py @@ -59,9 +59,6 @@ def _load_model(): checkpoint = torch.load(io.BytesIO(r.content), map_location=torch.device("cpu")) model = registry.get_model_class("equiformer_v2")( - None, - -1, - 1, use_pbc=True, regress_forces=True, otf_graph=True, diff --git a/tests/core/models/test_gemnet.py b/tests/core/models/test_gemnet.py index 3fa0c6bab..b4c5414cc 100644 --- a/tests/core/models/test_gemnet.py +++ b/tests/core/models/test_gemnet.py @@ -47,9 +47,6 @@ def load_model(request) -> None: setup_imports() model = registry.get_model_class("gemnet_t")( - None, - -1, - 1, cutoff=6.0, num_spherical=7, num_radial=128, diff --git a/tests/core/models/test_gemnet_oc.py b/tests/core/models/test_gemnet_oc.py index d84669750..7729c1448 100644 --- a/tests/core/models/test_gemnet_oc.py +++ b/tests/core/models/test_gemnet_oc.py @@ -58,9 +58,6 @@ def load_model(request) -> None: checkpoint = torch.load(io.BytesIO(r.content), map_location=torch.device("cpu")) model = registry.get_model_class("gemnet_oc")( - None, - -1, - 1, num_spherical=7, num_radial=128, num_blocks=4, diff --git a/tests/core/models/test_gemnet_oc_scaling_mismatch.py b/tests/core/models/test_gemnet_oc_scaling_mismatch.py index 8f1c36d27..29ea40c0f 100644 --- a/tests/core/models/test_gemnet_oc_scaling_mismatch.py +++ b/tests/core/models/test_gemnet_oc_scaling_mismatch.py @@ -35,9 +35,6 @@ def test_no_scaling_mismatch(self) -> None: checkpoint = torch.load(io.BytesIO(r.content), map_location=torch.device("cpu")) model = registry.get_model_class("gemnet_oc")( - None, - -1, - 1, num_spherical=7, num_radial=128, num_blocks=4, @@ -111,9 +108,6 @@ def test_scaling_mismatch(self) -> None: checkpoint = torch.load(io.BytesIO(r.content), map_location=torch.device("cpu")) model = registry.get_model_class("gemnet_oc")( - None, - -1, - 1, num_spherical=7, num_radial=128, num_blocks=4, @@ -189,9 +183,6 @@ 
def test_no_file_exists(self) -> None: with pytest.raises(ValueError): registry.get_model_class("gemnet_oc")( - None, - -1, - 1, num_spherical=7, num_radial=128, num_blocks=4, @@ -245,9 +236,6 @@ def test_not_fitted(self) -> None: setup_imports() model = registry.get_model_class("gemnet_oc")( - None, - -1, - 1, num_spherical=7, num_radial=128, num_blocks=4, diff --git a/tests/core/models/test_schnet.py b/tests/core/models/test_schnet.py index aa704604f..3dd21be4e 100644 --- a/tests/core/models/test_schnet.py +++ b/tests/core/models/test_schnet.py @@ -46,7 +46,7 @@ def load_model(request) -> None: setup_imports() model = registry.get_model_class("schnet")( - None, 32, 1, cutoff=6.0, regress_forces=True, use_pbc=True + cutoff=6.0, regress_forces=True, use_pbc=True ) request.cls.model = model
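
With the dataset-derived positional arguments (num_atoms, bond_feat_dim, num_targets) removed, every model in these tests is constructed from keyword arguments alone, which is also all that the simplified load_model in base_trainer.py passes through. A minimal sketch of that call path, transcribing the model section of test_schnet.yml above into a plain dict (the dict literal itself is just that transcription, not new configuration):

    from fairchem.core.common.registry import registry
    from fairchem.core.common.utils import setup_imports

    setup_imports()  # populate the model registry, as the model tests above do

    # model section of test_schnet.yml, minus the "name" key
    model_attributes = {
        "hidden_channels": 1024,
        "num_filters": 256,
        "num_interactions": 5,
        "num_gaussians": 200,
        "cutoff": 6.0,
        "use_pbc": True,
    }

    # equivalent to: registry.get_model_class(config["model"])(**config["model_attributes"])
    model = registry.get_model_class("schnet")(**model_attributes)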