diff --git a/ocpmodels/models/equiformer_v2/equiformer_v2_oc20.py b/ocpmodels/models/equiformer_v2/equiformer_v2_oc20.py index f8aa5a258..59fabeda5 100644 --- a/ocpmodels/models/equiformer_v2/equiformer_v2_oc20.py +++ b/ocpmodels/models/equiformer_v2/equiformer_v2_oc20.py @@ -95,6 +95,7 @@ class EquiformerV2_OC20(BaseModel): proj_drop (float): Dropout rate for outputs of attention and FFN in Transformer blocks weight_init (str): ['normal', 'uniform'] initialization of weights of linear layers except those in radial functions + enforce_max_neighbors_strictly (bool): When edges are subselected based on the `max_neighbors` arg, arbitrarily select amongst equidistant / degenerate edges to have exactly the correct number. """ def __init__( @@ -137,6 +138,7 @@ def __init__( drop_path_rate: float = 0.05, proj_drop: float = 0.0, weight_init: str = "normal", + enforce_max_neighbors_strictly: bool = True, ): super().__init__() @@ -198,6 +200,8 @@ def __init__( self.weight_init = weight_init assert self.weight_init in ["normal", "uniform"] + self.enforce_max_neighbors_strictly = enforce_max_neighbors_strictly + self.device = "cpu" # torch.cuda.current_device() self.grad_forces = False @@ -381,7 +385,10 @@ def forward(self, data): cell_offsets, _, # cell offset distances neighbors, - ) = self.generate_graph(data) + ) = self.generate_graph( + data, + enforce_max_neighbors_strictly=self.enforce_max_neighbors_strictly, + ) ############################################################### # Initialize data structures