Skip to content

Commit

Permalink
Adds flag to configure strict no. of nbrs or not in EquiformerV2 (#571)
Browse files Browse the repository at this point in the history
  • Loading branch information
abhshkdz authored Sep 1, 2023
1 parent d763606 commit d1ba3b8
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion ocpmodels/models/equiformer_v2/equiformer_v2_oc20.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class EquiformerV2_OC20(BaseModel):
proj_drop (float): Dropout rate for outputs of attention and FFN in Transformer blocks
weight_init (str): ['normal', 'uniform'] initialization of weights of linear layers except those in radial functions
enforce_max_neighbors_strictly (bool): When edges are subselected based on the `max_neighbors` arg, arbitrarily select amongst equidistant / degenerate edges to have exactly the correct number.
"""

def __init__(
Expand Down Expand Up @@ -137,6 +138,7 @@ def __init__(
drop_path_rate: float = 0.05,
proj_drop: float = 0.0,
weight_init: str = "normal",
enforce_max_neighbors_strictly: bool = True,
):
super().__init__()

Expand Down Expand Up @@ -198,6 +200,8 @@ def __init__(
self.weight_init = weight_init
assert self.weight_init in ["normal", "uniform"]

self.enforce_max_neighbors_strictly = enforce_max_neighbors_strictly

self.device = "cpu" # torch.cuda.current_device()

self.grad_forces = False
Expand Down Expand Up @@ -381,7 +385,10 @@ def forward(self, data):
cell_offsets,
_, # cell offset distances
neighbors,
) = self.generate_graph(data)
) = self.generate_graph(
data,
enforce_max_neighbors_strictly=self.enforce_max_neighbors_strictly,
)

###############################################################
# Initialize data structures
Expand Down

0 comments on commit d1ba3b8

Please sign in to comment.