Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move select models to backbone + heads format and add support for hydra #782

Merged
merged 46 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
abfdd98
convert escn to bb + heads
misko Jul 24, 2024
41065e9
convert dimenet to bb + heads
misko Jul 24, 2024
fd0ab8d
gemnet_oc to backbone and heads
misko Jul 24, 2024
c0d9da2
add additional parameter backbone config to heads
misko Jul 24, 2024
42f1a11
gemnet to bb and heads
misko Jul 24, 2024
91a538a
pain to bb and heads
misko Jul 24, 2024
c489cfc
add eqv2 bb+heads; move to canonical naming
misko Jul 25, 2024
da41bb1
fix calculator loading by leaving original class in code
misko Jul 26, 2024
cf0eb28
fix issues with calculator loading
misko Jul 26, 2024
ea3e967
lint fixes
misko Jul 26, 2024
3db53c3
move dimenet++ heads to one
misko Jul 29, 2024
73f89be
add test for dimenet
misko Jul 29, 2024
111d19e
add painn test
misko Jul 29, 2024
817e9fe
hydra and tests for gemnetH dppH painnH
misko Jul 30, 2024
0e72dd3
add escnH and equiformerv2H
misko Jul 30, 2024
ca807d3
add gemnetdt gemnetdtH
misko Jul 30, 2024
b9a2ff3
add smoke test for schnet and scn
misko Jul 30, 2024
52000ec
remove old examples
misko Jul 30, 2024
39f5e2e
typo
misko Jul 30, 2024
39f7fc6
fix gemnet with grad forces; add test for this
misko Jul 30, 2024
01689b6
remove unused params; add backbone and head interface; add typing
misko Jul 31, 2024
da02e04
remove unused second order output heads
misko Jul 31, 2024
eac0252
remove OC20 suffix from equiformer
misko Jul 31, 2024
7e5170f
remove comment
misko Jul 31, 2024
9154523
rename and lint
misko Jul 31, 2024
e2e5010
fix dimenet test
misko Jul 31, 2024
d753342
fix tests
misko Jul 31, 2024
f866322
Merge branch 'main' into hydra_support
misko Jul 31, 2024
366a42b
refactor generate graph
lbluque Aug 1, 2024
d65a7fe
refactor generate graph
lbluque Aug 1, 2024
18bcff2
fix a messy cherry pick
lbluque Aug 1, 2024
e5ceab8
final messy fix
lbluque Aug 1, 2024
fb7112e
graph data interface in eqv2
lbluque Aug 1, 2024
2e67c44
refactor
lbluque Aug 2, 2024
9bc3306
no bbconfigs
lbluque Aug 2, 2024
07fa4ac
no more headconfigs in inits
lbluque Aug 2, 2024
5867788
rename hydra
lbluque Aug 2, 2024
cd517a5
fix eqV2
lbluque Aug 2, 2024
384aba5
update test configs
lbluque Aug 2, 2024
c808fd9
final fixes
lbluque Aug 2, 2024
53dd05c
fix tutorial
lbluque Aug 2, 2024
d7a98ee
rm comments
lbluque Aug 2, 2024
23346b5
Merge pull request #791 from FAIR-Chem/more_hydra_support
lbluque Aug 2, 2024
51a11a2
Merge branch 'main' into hydra_support
misko Aug 2, 2024
00009fb
merge
misko Aug 2, 2024
e8e7eb7
fix test
misko Aug 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/legacy_tutorials/OCP_Tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -1807,7 +1807,7 @@ Similarly, to predict forces, we pass edge features through a fully-connected la

@registry.register_model("simple")
class SimpleAtomEdgeModel(torch.nn.Module):
def __init__(self, num_atoms, bond_feat_dim, num_targets, emb_size=64, num_radial=64, cutoff=6.0, env_exponent=5):
def __init__(self, emb_size=64, num_radial=64, cutoff=6.0, env_exponent=5):
super().__init__()

self.radial_basis = RadialBasis(
Expand Down
137 changes: 121 additions & 16 deletions src/fairchem/core/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,42 @@
from __future__ import annotations

import logging
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING

import torch
import torch.nn as nn
from torch import nn
from torch_geometric.nn import radius_graph

from fairchem.core.common.registry import registry
from fairchem.core.common.utils import (
compute_neighbors,
get_pbc_distances,
radius_graph_pbc,
)

if TYPE_CHECKING:
from torch_geometric.data import Batch
lbluque marked this conversation as resolved.
Show resolved Hide resolved

class BaseModel(nn.Module):
def __init__(self, num_atoms=None, bond_feat_dim=None, num_targets=None) -> None:
super().__init__()
self.num_atoms = num_atoms
self.bond_feat_dim = bond_feat_dim
self.num_targets = num_targets

def forward(self, data):
raise NotImplementedError
@dataclass
class GraphData:
    """Class to keep graph attributes nicely packaged.

    Returned by ``GraphModelMixin.generate_graph`` in place of the old
    6-tuple, so downstream models access graph attributes by name rather
    than by position.
    """

    # Edge connectivity; presumably shape (2, num_edges) — TODO confirm
    edge_index: torch.Tensor
    # Scalar distance per edge.
    edge_distance: torch.Tensor
    # Displacement vector per edge; presumably (num_edges, 3) — TODO confirm
    edge_distance_vec: torch.Tensor
    # Periodic-cell offset (in lattice units) applied to each edge.
    cell_offsets: torch.Tensor
    # Cell offsets converted to Cartesian distances.
    offset_distances: torch.Tensor
    # Number of neighbors per structure in the batch.
    neighbors: torch.Tensor
    batch_full: torch.Tensor  # used for GP functionality
    atomic_numbers_full: torch.Tensor  # used for GP functionality
    node_offset: int = 0  # used for GP functionality


class GraphModelMixin:
"""Mixin Model class implementing some general convenience properties and methods."""

def generate_graph(
self,
Expand Down Expand Up @@ -109,13 +124,16 @@ def generate_graph(
)
neighbors = compute_neighbors(data, edge_index)

return (
edge_index,
edge_dist,
distance_vec,
cell_offsets,
cell_offset_distances,
neighbors,
return GraphData(
edge_index=edge_index,
edge_distance=edge_dist,
edge_distance_vec=distance_vec,
cell_offsets=cell_offsets,
offset_distances=cell_offset_distances,
neighbors=neighbors,
node_offset=0,
batch_full=data.batch,
atomic_numbers_full=data.atomic_numbers.long(),
)

@property
Expand All @@ -130,3 +148,90 @@ def no_weight_decay(self) -> list:
if "embedding" in name or "frequencies" in name or "bias" in name:
no_wd_list.append(name)
return no_wd_list


class HeadInterface(metaclass=ABCMeta):
    """Interface that every output head must implement."""

    @abstractmethod
    def forward(
        self, data: Batch, emb: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        """Compute one or more output targets from backbone embeddings.

        Args:
            data: Batch of atomic systems used as model input.
            emb: Embedding dict produced by the backbone for this batch.

        Returns:
            Dict mapping target name to the tensor predicted by this head.
        """


class BackboneInterface(metaclass=ABCMeta):
    """Interface that every backbone must implement."""

    @abstractmethod
    def forward(self, data: Batch) -> dict[str, torch.Tensor]:
        """Compute embeddings for a batch of atomic systems.

        Args:
            data: Batch of atomic systems used as model input.

        Returns:
            Dict of named embedding tensors for the given input, to be
            consumed by one or more output heads.
        """


@registry.register_model("hydra")
class HydraModel(nn.Module, GraphModelMixin):
    """Model composed of one backbone and one or more output heads.

    The backbone maps the input batch to an embedding dict; each head maps
    (batch, embedding) to one or more named output targets. Heads are
    constructed in sorted-name order so building is deterministic.
    """

    def __init__(
        self,
        backbone: dict,
        heads: dict,
        otf_graph: bool = True,
    ):
        """
        Arguments
        ---------
        backbone: dict
            Backbone config. Must contain a "model" key naming a registered
            model class; remaining keys are passed as kwargs to it.
        heads: dict
            Mapping of head name -> head config. Each config must contain a
            "module" key naming a registered head class; remaining keys are
            passed as kwargs, after the backbone instance.
        otf_graph: bool
            Whether graphs are generated on the fly.
        """
        super().__init__()
        self.otf_graph = otf_graph

        # Work on shallow copies so popping keys below does not mutate the
        # caller's config dicts — in-place pops broke re-building the model
        # from the same config (e.g. a trainer re-instantiating from a saved
        # config would find "model"/"module" missing on the second build).
        backbone = dict(backbone)
        backbone_model_name = backbone.pop("model")
        self.backbone: BackboneInterface = registry.get_model_class(
            backbone_model_name
        )(
            **backbone,
        )

        # Iterate through the head configs and create each head.
        output_heads: dict[str, HeadInterface] = {}
        for head_name in sorted(heads):
            head_config = dict(heads[head_name])
            if "module" not in head_config:
                raise ValueError(
                    f"{head_name} head does not specify module to use for the head"
                )

            module_name = head_config.pop("module")
            output_heads[head_name] = registry.get_model_class(module_name)(
                self.backbone,
                **head_config,
            )

        # Register as a ModuleDict so head parameters are tracked/trained.
        self.output_heads = torch.nn.ModuleDict(output_heads)

    def forward(self, data: Batch):
        """Run the backbone once, then every head on the shared embedding."""
        emb = self.backbone(data)
        # Predict all output properties for all structures in the batch for now.
        out = {}
        for head in self.output_heads.values():
            out.update(head(data, emb))

        return out
Loading