
Commit

Move select models to backbone + heads format and add support for hydra (#782)

* convert escn to bb + heads

* convert dimenet to bb + heads

* gemnet_oc to backbone and heads

* add additional parameter backbone config to heads

* gemnet to bb and heads

* painn to bb and heads

* add eqv2 bb+heads; move to canonical naming

* fix calculator loading by leaving original class in code

* fix issues with calculator loading

* lint fixes

* move dimenet++ heads to one

* add test for dimenet

* add painn test

* hydra and tests for gemnetH dppH painnH

* add escnH and equiformerv2H

* add gemnetdt gemnetdtH

* add smoke test for schnet and scn

* remove old examples

* typo

* fix gemnet with grad forces; add test for this

* remove unused params; add backbone and head interface; add typing

* remove unused second order output heads

* remove OC20 suffix from equiformer

* remove comment

* rename and lint

* fix dimenet test

* fix tests

* refactor generate graph

* refactor generate graph

* fix a messy cherry pick

* final messy fix

* graph data interface in eqv2

* refactor

* no bbconfigs

* no more headconfigs in inits

* rename hydra

* fix eqV2

* update test configs

* final fixes

* fix tutorial

* rm comments

* fix test

---------

Co-authored-by: lbluque <[email protected]>
Co-authored-by: Luis Barroso-Luque <[email protected]>
3 people committed Aug 2, 2024
1 parent 04a69b0 commit 08b8c1e
Showing 34 changed files with 2,182 additions and 351 deletions.
2 changes: 1 addition & 1 deletion docs/legacy_tutorials/OCP_Tutorial.md
@@ -1807,7 +1807,7 @@ Similarly, to predict forces, we pass edge features through a fully-connected layer

@registry.register_model("simple")
class SimpleAtomEdgeModel(torch.nn.Module):
def __init__(self, num_atoms, bond_feat_dim, num_targets, emb_size=64, num_radial=64, cutoff=6.0, env_exponent=5):
def __init__(self, emb_size=64, num_radial=64, cutoff=6.0, env_exponent=5):
super().__init__()

self.radial_basis = RadialBasis(
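The tutorial change above drops the unused num_atoms, bond_feat_dim, and num_targets placeholders from the example model's constructor. A minimal sketch of the updated registration pattern; the Linear layer here is a stand-in assumption for the tutorial's RadialBasis:

    import torch

    from fairchem.core.common.registry import registry

    @registry.register_model("simple")
    class SimpleAtomEdgeModel(torch.nn.Module):
        # Only the model's own hyperparameters remain; the legacy
        # num_atoms/bond_feat_dim/num_targets arguments are gone.
        def __init__(self, emb_size=64, num_radial=64, cutoff=6.0, env_exponent=5):
            super().__init__()
            self.cutoff = cutoff
            self.env_exponent = env_exponent
            # Stand-in for the tutorial's RadialBasis (an assumption here).
            self.radial_basis = torch.nn.Linear(num_radial, emb_size)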
137 changes: 121 additions & 16 deletions src/fairchem/core/models/base.py
@@ -8,27 +8,42 @@
from __future__ import annotations

import logging
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING

import torch
import torch.nn as nn
from torch import nn
from torch_geometric.nn import radius_graph

from fairchem.core.common.registry import registry
from fairchem.core.common.utils import (
compute_neighbors,
get_pbc_distances,
radius_graph_pbc,
)

if TYPE_CHECKING:
from torch_geometric.data import Batch

class BaseModel(nn.Module):
def __init__(self, num_atoms=None, bond_feat_dim=None, num_targets=None) -> None:
super().__init__()
self.num_atoms = num_atoms
self.bond_feat_dim = bond_feat_dim
self.num_targets = num_targets

def forward(self, data):
raise NotImplementedError
@dataclass
class GraphData:
"""Class to keep graph attributes nicely packaged."""

edge_index: torch.Tensor
edge_distance: torch.Tensor
edge_distance_vec: torch.Tensor
cell_offsets: torch.Tensor
offset_distances: torch.Tensor
neighbors: torch.Tensor
batch_full: torch.Tensor # used for GP functionality
atomic_numbers_full: torch.Tensor # used for GP functionality
node_offset: int = 0 # used for GP functionality


class GraphModelMixin:
"""Mixin Model class implementing some general convenience properties and methods."""

def generate_graph(
self,
@@ -109,13 +124,16 @@ def generate_graph(
)
neighbors = compute_neighbors(data, edge_index)

return (
edge_index,
edge_dist,
distance_vec,
cell_offsets,
cell_offset_distances,
neighbors,
return GraphData(
edge_index=edge_index,
edge_distance=edge_dist,
edge_distance_vec=distance_vec,
cell_offsets=cell_offsets,
offset_distances=cell_offset_distances,
neighbors=neighbors,
node_offset=0,
batch_full=data.batch,
atomic_numbers_full=data.atomic_numbers.long(),
)
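generate_graph now returns the GraphData dataclass above instead of a six-element tuple, so call sites move from positional unpacking to named attribute access. A minimal sketch of the change inside the forward of a model that mixes in GraphModelMixin (variable names are illustrative):

    # Before this commit: positional unpacking of a 6-tuple.
    # (edge_index, edge_dist, distance_vec,
    #  cell_offsets, cell_offset_distances, neighbors) = self.generate_graph(data)

    # After this commit: one GraphData object with named fields.
    graph = self.generate_graph(data)
    edge_index = graph.edge_index
    edge_dist = graph.edge_distance
    distance_vec = graph.edge_distance_vec
    cell_offsets = graph.cell_offsets
    neighbors = graph.neighbors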

@property
@@ -130,3 +148,90 @@ def no_weight_decay(self) -> list:
if "embedding" in name or "frequencies" in name or "bias" in name:
no_wd_list.append(name)
return no_wd_list


class HeadInterface(metaclass=ABCMeta):
@abstractmethod
def forward(
self, data: Batch, emb: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
"""Head forward.
Arguments
---------
data: Batch
Atomic systems as input
emb: dict[str->torch.Tensor]
Embeddings of the input as generated by the backbone
Returns
-------
outputs: dict[str->torch.Tensor]
Return one or more targets generated by this head
"""
return


class BackboneInterface(metaclass=ABCMeta):
@abstractmethod
def forward(self, data: Batch) -> dict[str, torch.Tensor]:
"""Backbone forward.
Arguments
---------
data: Batch
Atomic systems as input
Returns
-------
embedding: dict[str->torch.Tensor]
Return backbone embeddings for the given input
"""
return
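
A concrete head subclasses both torch.nn.Module and HeadInterface, takes the instantiated backbone as its first constructor argument (see HydraModel below), and maps the backbone's embedding dict to named targets. A minimal sketch of a conforming head; the hidden_dim default, the "node_embedding" key, and the "energy" target name are assumptions, not part of this commit:

    import torch

    from fairchem.core.common.registry import registry
    from fairchem.core.models.base import HeadInterface

    @registry.register_model("demo_energy_head")  # assumed name for illustration
    class DemoEnergyHead(torch.nn.Module, HeadInterface):
        def __init__(self, backbone, hidden_dim: int = 128):
            # The backbone instance is passed in so a head can size its
            # layers from the backbone's configuration.
            super().__init__()
            self.linear = torch.nn.Linear(hidden_dim, 1)

        def forward(self, data, emb):
            # "node_embedding" is an assumed key; each backbone defines
            # the keys of the embedding dict it returns.
            per_node = self.linear(emb["node_embedding"])
            return {"energy": per_node.sum(dim=0)}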


@registry.register_model("hydra")
class HydraModel(nn.Module, GraphModelMixin):
def __init__(
self,
backbone: dict,
heads: dict,
otf_graph: bool = True,
):
super().__init__()
self.otf_graph = otf_graph

backbone_model_name = backbone.pop("model")
self.backbone: BackboneInterface = registry.get_model_class(
backbone_model_name
)(
**backbone,
)

# Iterate through outputs_cfg and create heads
self.output_heads: dict[str, HeadInterface] = {}

head_names_sorted = sorted(heads.keys())
for head_name in head_names_sorted:
head_config = heads[head_name]
if "module" not in head_config:
raise ValueError(
f"{head_name} head does not specify module to use for the head"
)

module_name = head_config.pop("module")
self.output_heads[head_name] = registry.get_model_class(module_name)(
self.backbone,
**head_config,
)

self.output_heads = torch.nn.ModuleDict(self.output_heads)

def forward(self, data: Batch):
emb = self.backbone(data)
# Predict all output properties for all structures in the batch for now.
out = {}
for k in self.output_heads:
out.update(self.output_heads[k](data, emb))

return out
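
HydraModel resolves the backbone class from the "model" key of the backbone config, then builds each head from its "module" key, passing the backbone instance as the first argument; forward runs the backbone once and merges every head's output dict. A hedged instantiation sketch, where the registered names "demo_backbone" and "demo_energy_head" are assumptions for illustration:

    from fairchem.core.common.registry import registry

    HydraModel = registry.get_model_class("hydra")
    model = HydraModel(
        backbone={"model": "demo_backbone"},               # assumed backbone name
        heads={"energy": {"module": "demo_energy_head"}},  # assumed head name
        otf_graph=True,
    )
    # model(data) runs the backbone once, hands the shared embedding dict
    # to every head, and merges their output dicts into one.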
