More type annotations (#544)
r-barnes committed Jul 28, 2023
1 parent cfd9b33 commit b700b7c
Showing 42 changed files with 577 additions and 509 deletions.
15 changes: 9 additions & 6 deletions ocpmodels/common/distutils.py
@@ -8,6 +8,7 @@
import logging
import os
import subprocess
from typing import List

import torch
import torch.distributed as dist
@@ -106,19 +107,19 @@ def cleanup() -> None:
dist.destroy_process_group()


def initialized():
def initialized() -> bool:
return dist.is_available() and dist.is_initialized()


def get_rank():
def get_rank() -> int:
return dist.get_rank() if initialized() else 0


def get_world_size():
def get_world_size() -> int:
return dist.get_world_size() if initialized() else 1


def is_master():
def is_master() -> bool:
return get_rank() == 0


@@ -138,7 +139,7 @@ def broadcast(

def all_reduce(
data, group=dist.group.WORLD, average: bool = False, device=None
):
) -> torch.Tensor:
if get_world_size() == 1:
return data
tensor = data
@@ -156,7 +157,9 @@ def all_reduce(
return result


def all_gather(data, group=dist.group.WORLD, device=None):
def all_gather(
data, group=dist.group.WORLD, device=None
) -> List[torch.Tensor]:
if get_world_size() == 1:
return data
tensor = data
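Aside: these helpers fall back to single-process defaults when torch.distributed is not initialized, which is what makes the added return annotations exact. A minimal, self-contained sketch of the same pattern (not the repository's exact code; in particular, the all_gather fallback here wraps the tensor in a list purely for illustration):

```python
from typing import List

import torch
import torch.distributed as dist


def initialized() -> bool:
    # True only when a default process group has been set up.
    return dist.is_available() and dist.is_initialized()


def get_rank() -> int:
    # Rank 0 is the natural default for single-process runs.
    return dist.get_rank() if initialized() else 0


def get_world_size() -> int:
    # World size 1 turns the collective below into a no-op.
    return dist.get_world_size() if initialized() else 1


def all_gather(data: torch.Tensor) -> List[torch.Tensor]:
    # Illustrative fallback: wrap the tensor so callers always get a list.
    if get_world_size() == 1:
        return [data]
    gathered = [torch.zeros_like(data) for _ in range(get_world_size())]
    dist.all_gather(gathered, data)
    return gathered


print(all_gather(torch.ones(3)))  # single process: [tensor([1., 1., 1.])]
```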
8 changes: 4 additions & 4 deletions ocpmodels/common/relaxation/ase_utils.py
@@ -66,9 +66,9 @@ def __init__(
config_yml=None,
checkpoint=None,
trainer=None,
cutoff=6,
max_neighbors=50,
cpu=True,
cutoff: int = 6,
max_neighbors: int = 50,
cpu: bool = True,
) -> None:
"""
OCP-ASE Calculator
@@ -194,7 +194,7 @@ def load_checkpoint(self, checkpoint_path: str) -> None:
except NotImplementedError:
logging.warning("Unable to load checkpoint!")

def calculate(self, atoms, properties, system_changes) -> None:
def calculate(self, atoms: Atoms, properties, system_changes) -> None:
Calculator.calculate(self, atoms, properties, system_changes)
data_object = self.a2g.convert(atoms)
batch = data_list_collater([data_object], otf_graph=True)
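The annotated calculate signature above is the standard ASE hook: the calculator converts an Atoms object to a graph, collates it into a single-item batch, and runs the model. A hypothetical usage sketch, assuming the calculator class in this module is OCPCalculator; the checkpoint path and slab setup are placeholders, not part of this commit:

```python
from ase.build import fcc111

from ocpmodels.common.relaxation.ase_utils import OCPCalculator

# Hypothetical example: the checkpoint path is a placeholder, and the keyword
# names mirror the __init__ signature shown in the hunk above.
slab = fcc111("Cu", size=(2, 2, 3), vacuum=8.0)
slab.calc = OCPCalculator(checkpoint="path/to/checkpoint.pt", cpu=True)
print(slab.get_potential_energy())  # energy predicted by the loaded model
```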
17 changes: 11 additions & 6 deletions ocpmodels/common/relaxation/ml_relaxation.py
@@ -8,10 +8,12 @@
import logging
from collections import deque
from pathlib import Path
from typing import Optional

import torch
from torch_geometric.data import Batch

from ocpmodels.common.typing import assert_is_instance
from ocpmodels.datasets.lmdb_dataset import data_list_collater

from .optimizers.lbfgs_torch import LBFGS, TorchCalc
@@ -20,13 +22,13 @@
def ml_relax(
batch,
model,
steps,
fmax,
steps: int,
fmax: float,
relax_opt,
save_full_traj,
device="cuda:0",
device: str = "cuda:0",
transform=None,
early_stop_batch=False,
early_stop_batch: bool = False,
):
"""
Runs ML-based relaxations.
@@ -66,18 +68,21 @@ def ml_relax(
traj_names=ids,
early_stop_batch=early_stop_batch,
)

e: Optional[RuntimeError] = None
try:
relaxed_batch = optimizer.run(fmax=fmax, steps=steps)
relaxed_batches.append(relaxed_batch)
except RuntimeError as e:
except RuntimeError as err:
e = err
oom = True
torch.cuda.empty_cache()

if oom:
# move OOM recovery code outside of except clause to allow tensors to be freed.
data_list = batch.to_data_list()
if len(data_list) == 1:
raise e
raise assert_is_instance(e, RuntimeError)
logging.info(
f"Failed to relax batch with size: {len(data_list)}, splitting into two..."
)
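The reshuffled exception handling above keeps the caught RuntimeError in a variable declared before the try, flags the OOM, and performs the batch-splitting recovery after the except clause so the failed attempt's tensors can actually be freed. A generic sketch of the same pattern, with a hypothetical step callable and a plain list standing in for the batch (not code from this commit):

```python
from typing import Callable, List, Optional

import torch


def run_with_oom_fallback(batch: List, step: Callable[[List], List]) -> List:
    # Hypothetical helper illustrating the recovery pattern used by ml_relax.
    e: Optional[RuntimeError] = None
    oom = False
    try:
        return step(batch)
    except RuntimeError as err:
        e = err  # keep the exception so it can be re-raised later
        oom = True
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if oom:
        # Recovery runs outside the except clause so the failed attempt's
        # tensors are no longer referenced and can be freed.
        if len(batch) == 1:
            assert e is not None
            raise e  # a single sample cannot be split any further
        mid = len(batch) // 2
        return run_with_oom_fallback(batch[:mid], step) + run_with_oom_fallback(
            batch[mid:], step
        )


# Example: the first call pretends to OOM, the retried halves succeed.
calls = {"n": 0}


def step(b: List) -> List:
    calls["n"] += 1
    if calls["n"] == 1:
        raise RuntimeError("CUDA out of memory (simulated)")
    return [x * 2 for x in b]


print(run_with_oom_fallback([1, 2, 3, 4], step))  # [2, 4, 6, 8]
```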
8 changes: 4 additions & 4 deletions ocpmodels/common/utils.py
@@ -208,9 +208,9 @@ def collate(data_list):
def add_edge_distance_to_graph(
batch,
device="cpu",
dmin=0.0,
dmax=6.0,
num_gaussians=50,
dmin: float = 0.0,
dmax: float = 6.0,
num_gaussians: int = 50,
):
# Make sure x has positions.
if not all(batch.pos[0][:] == batch.x[0][-3:]):
@@ -454,7 +454,7 @@ def build_config(args, args_override):
return config


def create_grid(base_config, sweep_file):
def create_grid(base_config, sweep_file: str):
def _flatten_sweeps(sweeps, root_key: str = "", sep: str = "."):
flat_sweeps = []
for key, value in sweeps.items():
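The dmin/dmax/num_gaussians parameters annotated above are the usual knobs of a Gaussian basis expansion of edge distances. A generic sketch of that technique under those assumptions (not the repository's implementation):

```python
import torch


def gaussian_expand(
    distances: torch.Tensor,
    dmin: float = 0.0,
    dmax: float = 6.0,
    num_gaussians: int = 50,
) -> torch.Tensor:
    # Expand each scalar distance into num_gaussians smooth features,
    # with Gaussian centers spaced evenly between dmin and dmax.
    offsets = torch.linspace(dmin, dmax, num_gaussians)  # (G,)
    width = offsets[1] - offsets[0]                      # uniform spacing
    diff = distances.unsqueeze(-1) - offsets             # (E, G)
    return torch.exp(-0.5 * (diff / width) ** 2)


# Example: expand 4 edge lengths into 50 features each.
feats = gaussian_expand(torch.tensor([1.2, 2.5, 3.0, 5.9]))
print(feats.shape)  # torch.Size([4, 50])
```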
2 changes: 1 addition & 1 deletion ocpmodels/datasets/ase_datasets.py
@@ -22,7 +22,7 @@


def apply_one_tags(
atoms, skip_if_nonzero: bool = True, skip_always: bool = False
atoms: ase.Atoms, skip_if_nonzero: bool = True, skip_always: bool = False
):
"""
This function will apply tags of 1 to an ASE atoms object.
3 changes: 3 additions & 0 deletions ocpmodels/datasets/lmdb_dataset.py
@@ -30,6 +30,9 @@
@registry.register_dataset("single_point_lmdb")
@registry.register_dataset("trajectory_lmdb")
class LmdbDataset(Dataset[T_co]):
metadata_path: Path
sharded: bool

r"""Dataset class to load from LMDB files containing relaxation
trajectories or single point computations.
Useful for Structure to Energy & Force (S2EF), Initial State to
56 changes: 28 additions & 28 deletions ocpmodels/models/dimenet_plus_plus.py
@@ -60,13 +60,13 @@
class InteractionPPBlock(torch.nn.Module):
def __init__(
self,
hidden_channels,
int_emb_size,
basis_emb_size,
num_spherical,
num_radial,
num_before_skip,
num_after_skip,
hidden_channels: int,
int_emb_size: int,
basis_emb_size: int,
num_spherical: int,
num_radial: int,
num_before_skip: int,
num_after_skip: int,
act="silu",
) -> None:
act = activation_resolver(act)
@@ -163,9 +163,9 @@ class OutputPPBlock(torch.nn.Module):
def __init__(
self,
num_radial: int,
hidden_channels,
out_emb_channels,
out_channels,
hidden_channels: int,
out_emb_channels: int,
out_channels: int,
num_layers: int,
act: str = "silu",
) -> None:
@@ -340,24 +340,24 @@ def forward(self, z, pos, batch=None):
class DimeNetPlusPlusWrap(DimeNetPlusPlus, BaseModel):
def __init__(
self,
num_atoms,
bond_feat_dim, # not used
num_targets,
use_pbc=True,
regress_forces=True,
hidden_channels=128,
num_blocks=4,
int_emb_size=64,
basis_emb_size=8,
out_emb_channels=256,
num_spherical=7,
num_radial=6,
otf_graph=False,
cutoff=10.0,
envelope_exponent=5,
num_before_skip=1,
num_after_skip=2,
num_output_layers=3,
num_atoms: int,
bond_feat_dim: int, # not used
num_targets: int,
use_pbc: bool = True,
regress_forces: bool = True,
hidden_channels: int = 128,
num_blocks: int = 4,
int_emb_size: int = 64,
basis_emb_size: int = 8,
out_emb_channels: int = 256,
num_spherical: int = 7,
num_radial: int = 6,
otf_graph: bool = False,
cutoff: float = 10.0,
envelope_exponent: int = 5,
num_before_skip: int = 1,
num_after_skip: int = 2,
num_output_layers: int = 3,
) -> None:
self.num_targets = num_targets
self.regress_forces = regress_forces
28 changes: 14 additions & 14 deletions ocpmodels/models/equiformer_v2/activation.py
@@ -4,7 +4,7 @@


class ScaledSiLU(nn.Module):
def __init__(self, inplace=False):
def __init__(self, inplace: bool = False) -> None:
super(ScaledSiLU, self).__init__()
self.inplace = inplace
self.scale_factor = 1.6791767923989418
@@ -21,7 +21,7 @@ def extra_repr(self):

# Reference: https://github.com/facebookresearch/llama/blob/main/llama/model.py#L175
class ScaledSwiGLU(nn.Module):
def __init__(self, in_channels, out_channels, bias=True):
def __init__(self, in_channels, out_channels, bias: bool = True) -> None:
super(ScaledSwiGLU, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@@ -39,7 +39,7 @@ def forward(self, inputs):

# Reference: https://github.com/facebookresearch/llama/blob/main/llama/model.py#L175
class SwiGLU(nn.Module):
def __init__(self, in_channels, out_channels, bias=True):
def __init__(self, in_channels, out_channels, bias: bool = True) -> None:
super(SwiGLU, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@@ -56,7 +56,7 @@ def forward(self, inputs):


class SmoothLeakyReLU(torch.nn.Module):
def __init__(self, negative_slope=0.2):
def __init__(self, negative_slope: float = 0.2) -> None:
super().__init__()
self.alpha = negative_slope

@@ -70,7 +70,7 @@ def extra_repr(self):


class ScaledSmoothLeakyReLU(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.act = SmoothLeakyReLU(0.2)
self.scale_factor = 1.531320475574866
@@ -85,7 +85,7 @@ def extra_repr(self):


class ScaledSigmoid(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.scale_factor = 1.8467055342154763

@@ -94,7 +94,7 @@ def forward(self, x):


class GateActivation(torch.nn.Module):
def __init__(self, lmax, mmax, num_channels):
def __init__(self, lmax, mmax, num_channels) -> None:
super().__init__()

self.lmax = lmax
@@ -103,14 +103,14 @@ def __init__(self, lmax, mmax, num_channels):

# compute `expand_index` based on `lmax` and `mmax`
num_components = 0
for l in range(1, self.lmax + 1):
num_m_components = min((2 * l + 1), (2 * self.mmax + 1))
for lval in range(1, self.lmax + 1):
num_m_components = min((2 * lval + 1), (2 * self.mmax + 1))
num_components = num_components + num_m_components
expand_index = torch.zeros([num_components]).long()
start_idx = 0
for l in range(1, self.lmax + 1):
length = min((2 * l + 1), (2 * self.mmax + 1))
expand_index[start_idx : (start_idx + length)] = l - 1
for lval in range(1, self.lmax + 1):
length = min((2 * lval + 1), (2 * self.mmax + 1))
expand_index[start_idx : (start_idx + length)] = lval - 1
start_idx = start_idx + length
self.register_buffer("expand_index", expand_index)

@@ -153,7 +153,7 @@ class S2Activation(torch.nn.Module):
Assume we only have one resolution
"""

def __init__(self, lmax, mmax):
def __init__(self, lmax, mmax) -> None:
super().__init__()
self.lmax = lmax
self.mmax = mmax
@@ -173,7 +173,7 @@ def forward(self, inputs, SO3_grid):


class SeparableS2Activation(torch.nn.Module):
def __init__(self, lmax, mmax):
def __init__(self, lmax, mmax) -> None:
super().__init__()

self.lmax = lmax
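The loop-variable rename from l to lval in GateActivation avoids the ambiguous single-letter name (flake8 E741) without changing behaviour: the loop still maps every non-scalar (l, m) component to degree index l - 1, so one gate value per degree can be broadcast across all of that degree's orders. The same construction, pulled out as a standalone sketch for clarity (not repository code):

```python
import torch


def build_expand_index(lmax: int, mmax: int) -> torch.Tensor:
    # Count the (l, m) components kept for degrees 1..lmax, capped by mmax.
    num_components = 0
    for lval in range(1, lmax + 1):
        num_components += min(2 * lval + 1, 2 * mmax + 1)
    expand_index = torch.zeros([num_components]).long()
    # Fill each degree's block with its index (l - 1) for later gathering.
    start_idx = 0
    for lval in range(1, lmax + 1):
        length = min(2 * lval + 1, 2 * mmax + 1)
        expand_index[start_idx : start_idx + length] = lval - 1
        start_idx += length
    return expand_index


print(build_expand_index(lmax=2, mmax=1))  # tensor([0, 0, 0, 1, 1, 1])
```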
10 changes: 5 additions & 5 deletions ocpmodels/models/equiformer_v2/drop.py
@@ -35,7 +35,7 @@ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

def __init__(self, drop_prob=None):
def __init__(self, drop_prob=None) -> None:
super(DropPath, self).__init__()
self.drop_prob = drop_prob

@@ -51,7 +51,7 @@ class GraphDropPath(nn.Module):
Consider batch for graph data when dropping paths.
"""

def __init__(self, drop_prob=None):
def __init__(self, drop_prob=None) -> None:
super(GraphDropPath, self).__init__()
self.drop_prob = drop_prob

@@ -70,7 +70,7 @@ def extra_repr(self):


class EquivariantDropout(nn.Module):
def __init__(self, irreps, drop_prob):
def __init__(self, irreps, drop_prob) -> None:
super(EquivariantDropout, self).__init__()
self.irreps = irreps
self.num_irreps = irreps.num_irreps
@@ -91,7 +91,7 @@ def forward(self, x):


class EquivariantScalarsDropout(nn.Module):
def __init__(self, irreps, drop_prob):
def __init__(self, irreps, drop_prob) -> None:
super(EquivariantScalarsDropout, self).__init__()
self.irreps = irreps
self.drop_prob = drop_prob
@@ -117,7 +117,7 @@ def extra_repr(self):


class EquivariantDropoutArraySphericalHarmonics(nn.Module):
def __init__(self, drop_prob, drop_graph=False):
def __init__(self, drop_prob, drop_graph: bool = False) -> None:
super(EquivariantDropoutArraySphericalHarmonics, self).__init__()
self.drop_prob = drop_prob
self.drop = torch.nn.Dropout(drop_prob, True)
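DropPath and GraphDropPath above wrap a drop_path helper (visible in the first hunk header of this file) that implements stochastic depth. A generic sketch of that technique, not necessarily the repository's exact implementation:

```python
import torch


def drop_path(
    x: torch.Tensor, drop_prob: float = 0.0, training: bool = False
) -> torch.Tensor:
    # Stochastic depth: randomly zero whole samples and rescale the survivors.
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # one Bernoulli draw per sample
    mask = torch.rand(shape, dtype=x.dtype, device=x.device) < keep_prob
    return x * mask / keep_prob  # preserve the expected activation magnitude


x = torch.ones(4, 3)
print(drop_path(x, drop_prob=0.5, training=True))  # rows are either 0.0 or 2.0
```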