Move select models to backbone + heads format and add support for hydra #782

Merged: 46 commits, Aug 2, 2024
Changes shown are from 20 commits.

Commits (46)
abfdd98 - convert escn to bb + heads (misko, Jul 24, 2024)
41065e9 - convert dimenet to bb + heads (misko, Jul 24, 2024)
fd0ab8d - gemnet_oc to backbone and heads (misko, Jul 24, 2024)
c0d9da2 - add additional parameter backbone config to heads (misko, Jul 24, 2024)
42f1a11 - gemnet to bb and heads (misko, Jul 24, 2024)
91a538a - pain to bb and heads (misko, Jul 24, 2024)
c489cfc - add eqv2 bb+heads; move to canonical naming (misko, Jul 25, 2024)
da41bb1 - fix calculator loading by leaving original class in code (misko, Jul 26, 2024)
cf0eb28 - fix issues with calculator loading (misko, Jul 26, 2024)
ea3e967 - lint fixes (misko, Jul 26, 2024)
3db53c3 - move dimenet++ heads to one (misko, Jul 29, 2024)
73f89be - add test for dimenet (misko, Jul 29, 2024)
111d19e - add painn test (misko, Jul 29, 2024)
817e9fe - hydra and tests for gemnetH dppH painnH (misko, Jul 30, 2024)
0e72dd3 - add escnH and equiformerv2H (misko, Jul 30, 2024)
ca807d3 - add gemnetdt gemnetdtH (misko, Jul 30, 2024)
b9a2ff3 - add smoke test for schnet and scn (misko, Jul 30, 2024)
52000ec - remove old examples (misko, Jul 30, 2024)
39f5e2e - typo (misko, Jul 30, 2024)
39f7fc6 - fix gemnet with grad forces; add test for this (misko, Jul 30, 2024)
01689b6 - remove unused params; add backbone and head interface; add typing (misko, Jul 31, 2024)
da02e04 - remove unused second order output heads (misko, Jul 31, 2024)
eac0252 - remove OC20 suffix from equiformer (misko, Jul 31, 2024)
7e5170f - remove comment (misko, Jul 31, 2024)
9154523 - rename and lint (misko, Jul 31, 2024)
e2e5010 - fix dimenet test (misko, Jul 31, 2024)
d753342 - fix tests (misko, Jul 31, 2024)
f866322 - Merge branch 'main' into hydra_support (misko, Jul 31, 2024)
366a42b - refactor generate graph (lbluque, Aug 1, 2024)
d65a7fe - refactor generate graph (lbluque, Aug 1, 2024)
18bcff2 - fix a messy cherry pick (lbluque, Aug 1, 2024)
e5ceab8 - final messy fix (lbluque, Aug 1, 2024)
fb7112e - graph data interface in eqv2 (lbluque, Aug 1, 2024)
2e67c44 - refactor (lbluque, Aug 2, 2024)
9bc3306 - no bbconfigs (lbluque, Aug 2, 2024)
07fa4ac - no more headconfigs in inits (lbluque, Aug 2, 2024)
5867788 - rename hydra (lbluque, Aug 2, 2024)
cd517a5 - fix eqV2 (lbluque, Aug 2, 2024)
384aba5 - update test configs (lbluque, Aug 2, 2024)
c808fd9 - final fixes (lbluque, Aug 2, 2024)
53dd05c - fix tutorial (lbluque, Aug 2, 2024)
d7a98ee - rm comments (lbluque, Aug 2, 2024)
23346b5 - Merge pull request #791 from FAIR-Chem/more_hydra_support (lbluque, Aug 2, 2024)
51a11a2 - Merge branch 'main' into hydra_support (misko, Aug 2, 2024)
00009fb - merge (misko, Aug 2, 2024)
e8e7eb7 - fix test (misko, Aug 2, 2024)
108 changes: 97 additions & 11 deletions src/fairchem/core/models/dimenet_plus_plus.py
@@ -241,7 +241,6 @@ def __init__(
act = activation_resolver(act)

super().__init__()

self.cutoff = cutoff

if sym is None:
@@ -330,6 +329,33 @@ def forward(self, z, pos, batch=None):
raise NotImplementedError


@registry.register_model("dimenetplusplus_energy_and_force_head")
class DimeNetPlusPlusWrap_energy_and_force_head(nn.Module):
def __init__(self, backbone, backbone_config, head_config):
super().__init__()
self.regress_forces = backbone.regress_forces

@conditional_grad(torch.enable_grad())
def forward(self, x, emb):
outputs = {
"energy": (
emb["P"].sum(dim=0)
if x.batch is None
else scatter(emb["P"], x.batch, dim=0)
)
}
if self.regress_forces:
outputs["forces"] = -1 * (
torch.autograd.grad(
outputs["energy"],
x.pos,
grad_outputs=torch.ones_like(outputs["energy"]),
create_graph=True,
)[0]
)
return outputs
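
The head above recovers forces as the negative gradient of the predicted energy with respect to atomic positions. For readers unfamiliar with that pattern, here is a minimal, self-contained sketch of the same autograd technique; the toy energy function and tensor shapes are assumptions for illustration only and are not part of this PR.

import torch

# Toy positions for 5 "atoms"; requires_grad lets autograd differentiate
# the energy with respect to the positions.
pos = torch.randn(5, 3, requires_grad=True)

# Hypothetical stand-in for a model's predicted energy: a smooth scalar
# function of the positions (sum of pairwise squared distances).
energy = (pos.unsqueeze(0) - pos.unsqueeze(1)).pow(2).sum()

# Forces are the negative gradient of the energy w.r.t. positions,
# mirroring the torch.autograd.grad call in the head above.
forces = -1 * torch.autograd.grad(
    energy,
    pos,
    grad_outputs=torch.ones_like(energy),
    create_graph=True,
)[0]

print(forces.shape)  # torch.Size([5, 3])
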


@registry.register_model("dimenetplusplus")
class DimeNetPlusPlusWrap(DimeNetPlusPlus, BaseModel):
def __init__(
Expand Down Expand Up @@ -441,16 +467,13 @@ def forward(self, data):
outputs = {"energy": energy}

if self.regress_forces:
forces = -1 * (
    torch.autograd.grad(
        energy,
        data.pos,
        grad_outputs=torch.ones_like(energy),
        create_graph=True,
    )[0]
)
outputs["forces"] = forces

@@ -459,3 +482,66 @@ def forward(self, data):
@property
def num_params(self) -> int:
return sum(p.numel() for p in self.parameters())


@registry.register_model("dimenetplusplus_backbone")
class DimeNetPlusPlusWrapBB(DimeNetPlusPlusWrap):

@conditional_grad(torch.enable_grad())
def forward(self, data):
if self.regress_forces:
data.pos.requires_grad_(True)
pos = data.pos
(
edge_index,
dist,
_,
cell_offsets,
offsets,
neighbors,
) = self.generate_graph(data)

data.edge_index = edge_index
data.cell_offsets = cell_offsets
data.neighbors = neighbors
j, i = edge_index

_, _, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets(
edge_index,
data.cell_offsets,
num_nodes=data.atomic_numbers.size(0),
)

# Calculate angles.
pos_i = pos[idx_i].detach()
pos_j = pos[idx_j].detach()
if self.use_pbc:
pos_ji, pos_kj = (
pos[idx_j].detach() - pos_i + offsets[idx_ji],
pos[idx_k].detach() - pos_j + offsets[idx_kj],
)
else:
pos_ji, pos_kj = (
pos[idx_j].detach() - pos_i,
pos[idx_k].detach() - pos_j,
)

a = (pos_ji * pos_kj).sum(dim=-1)
b = torch.cross(pos_ji, pos_kj).norm(dim=-1)
angle = torch.atan2(b, a)

rbf = self.rbf(dist)
sbf = self.sbf(dist, angle, idx_kj)

# Embedding block.
x = self.emb(data.atomic_numbers.long(), rbf, i, j)
P = self.output_blocks[0](x, rbf, i, num_nodes=pos.size(0))

# Interaction blocks.
for interaction_block, output_block in zip(
self.interaction_blocks, self.output_blocks[1:]
):
x = interaction_block(x, rbf, sbf, idx_kj, idx_ji)
P += output_block(x, rbf, i, num_nodes=pos.size(0))

return {"P": P, "edge_embedding": x, "edge_idx": i}
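
The backbone above returns an embedding dictionary ("P", "edge_embedding", "edge_idx"), and the energy-and-force head earlier in this file consumes it via forward(data, emb). A minimal sketch of how such a backbone and its heads could be composed is shown below; the BackboneHeadsModel wrapper is hypothetical and only illustrates the interface visible in this diff, not necessarily how the PR's actual hydra model is wired.

import torch.nn as nn

class BackboneHeadsModel(nn.Module):
    # Hypothetical composition wrapper, sketched from the interfaces in this
    # diff: a backbone maps a data batch to an embedding dict, and each head
    # maps (data, emb) to an output dict such as {"energy": ...}.
    def __init__(self, backbone: nn.Module, heads: dict[str, nn.Module]):
        super().__init__()
        self.backbone = backbone
        self.heads = nn.ModuleDict(heads)

    def forward(self, data):
        # Run the backbone once and share its embeddings across all heads.
        emb = self.backbone(data)
        outputs = {}
        for name, head in self.heads.items():
            outputs.update(head(data, emb))
        return outputs

Under this pattern, DimeNetPlusPlusWrapBB would supply the emb dict above and the registered dimenetplusplus_energy_and_force_head would turn it into energy and force predictions.
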
222 changes: 222 additions & 0 deletions src/fairchem/core/models/equiformer_v2/equiformer_v2_oc20.py
@@ -678,3 +678,225 @@ def no_weight_decay(self) -> set:
no_wd_list.append(global_parameter_name)

return set(no_wd_list)


@registry.register_model("equiformer_v2_backbone")
class EquiformerV2_OC20BB(EquiformerV2_OC20):

@conditional_grad(torch.enable_grad())
def forward(self, data):
self.batch_size = len(data.natoms)
self.dtype = data.pos.dtype
self.device = data.pos.device
atomic_numbers = data.atomic_numbers.long()

(
edge_index,
edge_distance,
edge_distance_vec,
cell_offsets,
_, # cell offset distances
neighbors,
) = self.generate_graph(
data,
enforce_max_neighbors_strictly=self.enforce_max_neighbors_strictly,
)

data_batch_full = data.batch
data_batch = data.batch
atomic_numbers_full = atomic_numbers
node_offset = 0
if gp_utils.initialized():
(
atomic_numbers,
data_batch,
node_offset,
edge_index,
edge_distance,
edge_distance_vec,
) = self._init_gp_partitions(
atomic_numbers_full,
data_batch_full,
edge_index,
edge_distance,
edge_distance_vec,
)
###############################################################
# Entering graph parallel region.
# After this point, if graph parallel (gp) is used, node and edge tensors
# are split across the graph parallel ranks. Some full tensors, such as
# atomic_numbers_full, are still required because we need to index into
# the full graph when computing edge embeddings or reducing over
# neighboring nodes.
#
# All tensors without the "_full" suffix refer to the partial tensors.
# If gp is not used, the full values equal the partial values,
# i.e. atomic_numbers_full == atomic_numbers.
###############################################################

###############################################################
# Initialize data structures
###############################################################

# Compute 3x3 rotation matrix per edge
edge_rot_mat = self._init_edge_rot_mat(data, edge_index, edge_distance_vec)

# Initialize the WignerD matrices and other values for spherical harmonic calculations
for i in range(self.num_resolutions):
self.SO3_rotation[i].set_wigner(edge_rot_mat)

###############################################################
# Initialize node embeddings
###############################################################

# Init per node representations using an atomic number based embedding
offset = 0
x = SO3_Embedding(
len(atomic_numbers),
self.lmax_list,
self.sphere_channels,
self.device,
self.dtype,
)

offset_res = 0
offset = 0
# Initialize the l = 0, m = 0 coefficients for each resolution
for i in range(self.num_resolutions):
if self.num_resolutions == 1:
x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers)
else:
x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers)[
:, offset : offset + self.sphere_channels
]
offset = offset + self.sphere_channels
offset_res = offset_res + int((self.lmax_list[i] + 1) ** 2)

# Edge encoding (distance and atom edge)
edge_distance = self.distance_expansion(edge_distance)
if self.share_atom_edge_embedding and self.use_atom_edge_embedding:
source_element = atomic_numbers_full[
edge_index[0]
] # Source atom atomic number
target_element = atomic_numbers_full[
edge_index[1]
] # Target atom atomic number
source_embedding = self.source_embedding(source_element)
target_embedding = self.target_embedding(target_element)
edge_distance = torch.cat(
(edge_distance, source_embedding, target_embedding), dim=1
)

# Edge-degree embedding
edge_degree = self.edge_degree_embedding(
atomic_numbers_full,
edge_distance,
edge_index,
len(atomic_numbers),
node_offset,
)
x.embedding = x.embedding + edge_degree.embedding

###############################################################
# Update spherical node embeddings
###############################################################

for i in range(self.num_layers):
x = self.blocks[i](
x, # SO3_Embedding
atomic_numbers_full,
edge_distance,
edge_index,
batch=data_batch, # for GraphDropPath
node_offset=node_offset,
)

# Final layer norm
x.embedding = self.norm(x.embedding)

return {
"node_embedding": x,
"edge_distance": edge_distance,
"edge_index": edge_index,
# returning this only because it's cast to long and
# we don't want to repeat this.
"atomic_numbers": atomic_numbers_full,
# TODO: this is only used by graph parallel to split up the partitions,
# should figure out cleaner way to pass this around to the heads
"node_offset": node_offset,
}
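
For reference, the dictionary returned by this backbone's forward pass can be summarized with a typed sketch. The EquiformerV2BackboneOutput name and the TypedDict are illustrative only; the code above returns a plain dict.

from typing import TypedDict
import torch

class EquiformerV2BackboneOutput(TypedDict):
    # Illustrative summary of the keys returned above; not part of the PR.
    node_embedding: "SO3_Embedding"  # final per-node spherical-harmonic features
    edge_distance: torch.Tensor      # expanded edge distances (plus optional atom-edge embeddings)
    edge_index: torch.Tensor         # [2, num_edges] source/target node indices
    atomic_numbers: torch.Tensor     # full atomic numbers, already cast to long
    node_offset: int                 # offset of this rank's node partition under graph parallelism
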


@registry.register_model("equiformer_v2_energy_head")
class EquiformerV2_OC20_energy_head(nn.Module):
[Review comment, Collaborator] Same nit comment about camel caps and removing OC20 from the name.

def __init__(self, backbone, backbone_config, head_config):
super().__init__()
self.avg_num_nodes = backbone.avg_num_nodes
self.energy_block = FeedForwardNetwork(
backbone.sphere_channels,
backbone.ffn_hidden_channels,
1,
backbone.lmax_list,
backbone.mmax_list,
backbone.SO3_grid,
backbone.ffn_activation,
backbone.use_gate_act,
backbone.use_grid_mlp,
backbone.use_sep_s2_act,
)

def forward(self, x, emb):
node_energy = self.energy_block(emb["node_embedding"])
node_energy = node_energy.embedding.narrow(1, 0, 1)
if gp_utils.initialized():
node_energy = gp_utils.gather_from_model_parallel_region(node_energy, dim=0)
energy = torch.zeros(
len(x.natoms),
device=node_energy.device,
dtype=node_energy.dtype,
)
energy.index_add_(0, x.batch, node_energy.view(-1))
return {"energy": energy / self.avg_num_nodes}
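
The pooling step in this head (scatter-summing per-node energies into per-structure energies with index_add_, then dividing by avg_num_nodes) can be seen in isolation in the short sketch below; the batch assignment and energy values are made-up example numbers.

import torch

# Two structures in one batch: atoms 0-1 belong to structure 0, atoms 2-4 to structure 1.
batch = torch.tensor([0, 0, 1, 1, 1])
node_energy = torch.tensor([0.5, 1.0, 0.25, 0.25, 0.5])

# Same pattern as the energy head: accumulate per-node energies per structure.
energy = torch.zeros(2, dtype=node_energy.dtype)
energy.index_add_(0, batch, node_energy)

print(energy)  # tensor([1.5000, 1.0000])
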


@registry.register_model("equiformer_v2_force_head")
class EquiformerV2_OC20_force_head(nn.Module):
[Review comment, Collaborator] camel caps and OC20

def __init__(self, backbone, backbone_config, head_config):
super().__init__()

self.force_block = SO2EquivariantGraphAttention(
backbone.sphere_channels,
backbone.attn_hidden_channels,
backbone.num_heads,
backbone.attn_alpha_channels,
backbone.attn_value_channels,
1,
backbone.lmax_list,
backbone.mmax_list,
backbone.SO3_rotation,
backbone.mappingReduced,
backbone.SO3_grid,
backbone.max_num_elements,
backbone.edge_channels_list,
backbone.block_use_atom_edge_embedding,
backbone.use_m_share_rad,
backbone.attn_activation,
backbone.use_s2_act_attn,
backbone.use_attn_renorm,
backbone.use_gate_act,
backbone.use_sep_s2_act,
alpha_drop=0.0,
)

def forward(self, x, emb):
forces = self.force_block(
emb["node_embedding"],
x.atomic_numbers.long(),
emb["edge_distance"],
emb["edge_index"],
node_offset=emb["node_offset"],
)
forces = forces.embedding.narrow(1, 1, 3)
forces = forces.view(-1, 3).contiguous()
if gp_utils.initialized():
forces = gp_utils.gather_from_model_parallel_region(forces, dim=0)
return {"forces": forces}
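
The narrow(1, 1, 3) call above selects the three l = 1 coefficients of the SO(3) embedding (index 0 along that dimension is the l = 0 scalar used by the energy head), which serve as the per-atom force vectors. A tiny sketch of that indexing, with made-up shapes, follows.

import torch

# Toy stand-in for a head output: 4 nodes, 9 spherical coefficients (lmax = 2), 1 channel.
coeffs = torch.randn(4, 9, 1)

scalar_part = coeffs.narrow(1, 0, 1)  # index 0 along dim 1: the l = 0 coefficient
vector_part = coeffs.narrow(1, 1, 3)  # indices 1..3 along dim 1: the l = 1 block

print(scalar_part.shape)                 # torch.Size([4, 1, 1])
print(vector_part.reshape(-1, 3).shape)  # torch.Size([4, 3]), as in the force head
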