Skip to content

Commit

Permalink
Support for passing in stats to Equiformer V2 model (#576)
Browse files Browse the repository at this point in the history
Co-authored-by: Abhishek Das <[email protected]>
  • Loading branch information
anuroopsriram and abhshkdz authored Sep 14, 2023
1 parent 0505017 commit 936a7be
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions ocpmodels/models/equiformer_v2/equiformer_v2_oc20.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def __init__(
proj_drop: float = 0.0,
weight_init: str = "normal",
enforce_max_neighbors_strictly: bool = True,
avg_num_nodes: Optional[float] = None,
avg_degree: Optional[float] = None,
):
super().__init__()

Expand Down Expand Up @@ -197,6 +199,9 @@ def __init__(
self.drop_path_rate = drop_path_rate
self.proj_drop = proj_drop

self.avg_num_nodes = avg_num_nodes or _AVG_NUM_NODES
self.avg_degree = avg_degree or _AVG_DEGREE

self.weight_init = weight_init
assert self.weight_init in ["normal", "uniform"]

Expand Down Expand Up @@ -286,7 +291,7 @@ def __init__(
self.max_num_elements,
self.edge_channels_list,
self.block_use_atom_edge_embedding,
rescale_factor=_AVG_DEGREE,
rescale_factor=self.avg_degree,
)

# Initialize the blocks for each layer of EquiformerV2
Expand Down Expand Up @@ -480,7 +485,7 @@ def forward(self, data):
dtype=node_energy.dtype,
)
energy.index_add_(0, data.batch, node_energy.view(-1))
energy = energy / _AVG_NUM_NODES
energy = energy / self.avg_num_nodes

###############################################################
# Force estimation
Expand Down

0 comments on commit 936a7be

Please sign in to comment.