
Commit

mypy fixes.
Shyue Ping Ong committed Jun 20, 2023
1 parent 25766ec commit d15c2e3
Showing 17 changed files with 70 additions and 65 deletions.
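Most of the changes below follow a small set of mypy-driven patterns: type annotations that used torch.tensor (a factory function, which mypy rejects as a type) are replaced with the torch.Tensor class, empty containers and re-used variables get explicit annotations, and a few targeted # type: ignore comments silence the remaining warnings. A minimal sketch of the annotation pattern, assuming a hypothetical scale() function purely for illustration:

    import torch

    # Before: "def scale(x: torch.tensor) -> torch.tensor" fails type checking,
    # because torch.tensor is a function, not a type.
    # After: annotate with the torch.Tensor class instead.
    def scale(x: torch.Tensor, factor: float = 2.0) -> torch.Tensor:
        """Return the input tensor multiplied by a constant factor."""
        return x * factor

    print(scale(torch.ones(3)))  # tensor([2., 2., 2.])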
14 changes: 7 additions & 7 deletions matgl/apps/pes.py
@@ -19,8 +19,8 @@ class Potential(nn.Module, IOMixIn):
def __init__(
self,
model: nn.Module,
data_mean: torch.tensor | None = None,
data_std: torch.tensor | None = None,
data_mean: torch.Tensor | None = None,
data_std: torch.Tensor | None = None,
element_refs: np.ndarray | None = None,
calc_forces: bool = True,
calc_stresses: bool = True,
@@ -53,15 +53,15 @@ def __init__(
self.data_std = data_std if data_std is not None else torch.ones(1)

def forward(
self, g: dgl.DGLGraph, state_attr: torch.tensor | None = None, l_g: dgl.DGLGraph | None = None
self, g: dgl.DGLGraph, state_attr: torch.Tensor | None = None, l_g: dgl.DGLGraph | None = None
) -> tuple:
"""Args:
g: DGL graph
state_attr: State attrs
l_g: Line graph.
Returns:
energies, forces, stresses, hessian: torch.tensor
energies, forces, stresses, hessian: torch.Tensor
"""
forces = torch.zeros(1)
stresses = torch.zeros(1)
@@ -93,13 +93,13 @@ def forward(
hessian[iatom] = tmp.view(-1)
if self.calc_stresses:
f_ij = -grads[1]
stresses = []
sts: list = []
count_edge = 0
count_node = 0
for graph_id in range(g.batch_size):
num_edges = g.batch_num_edges()[graph_id]
num_nodes = 0
stresses.append(
sts.append(
-1
* (
160.21766208
@@ -113,6 +113,6 @@ def forward(
count_edge = count_edge + num_edges
num_nodes = g.batch_num_nodes()[graph_id]
count_node = count_node + num_nodes
stresses = torch.cat(stresses)
stresses = torch.cat(sts)

return total_energies, forces, stresses, hessian
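The stresses/sts rename above addresses another common mypy complaint: re-binding a single variable from a Python list to a torch.Tensor gives it an inconsistent inferred type. A small sketch of the pattern, with illustrative values only:

    import torch

    # Accumulate per-graph stress tensors in an explicitly annotated list,
    # then concatenate them under a separate name, so mypy never sees one
    # variable holding both a list and a Tensor.
    sts: list = [torch.eye(3), 2 * torch.eye(3)]
    stresses = torch.cat(sts)
    print(stresses.shape)  # torch.Size([6, 3])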
10 changes: 5 additions & 5 deletions matgl/ext/ase.py
@@ -115,7 +115,7 @@ class M3GNetCalculator(Calculator):
def __init__(
self,
potential: Potential,
state_attr: torch.tensor = None,
state_attr: torch.Tensor = None,
stress_weight: float = 1.0,
**kwargs,
):
@@ -179,7 +179,7 @@ class Relaxer:
def __init__(
self,
potential: Potential = None,
state_attr: torch.tensor = None,
state_attr: torch.Tensor = None,
optimizer: Optimizer | str = "FIRE",
relax_cell: bool = True,
stress_weight: float = 0.01,
@@ -188,7 +188,7 @@ def __init__(
Args:
potential (Potential): a M3GNet potential, a str path to a saved model or a short name for saved model
that comes with M3GNet distribution
state_attr (torch.tensor): State attr.
state_attr (torch.Tensor): State attr.
optimizer (str or ase Optimizer): the optimization algorithm.
Defaults to "FIRE"
relax_cell (bool): whether to relax the lattice cell
@@ -312,7 +312,7 @@ def __init__(
self,
atoms: Atoms,
potential: Potential,
state_attr: torch.tensor = None,
state_attr: torch.Tensor = None,
ensemble: str = "nvt",
temperature: int = 300,
timestep: float = 1.0,
@@ -332,7 +332,7 @@ def __init__(
atoms (Atoms): atoms to run the MD
potential (Potential): potential for calculating the energy, force,
stress of the atoms
state_attr (torch.tensor): State attr.
state_attr (torch.Tensor): State attr.
ensemble (str): choose from 'nvt' or 'npt'. NPT is not tested,
use with extra caution
temperature (float): temperature for MD simulation, in K
2 changes: 1 addition & 1 deletion matgl/ext/pymatgen.py
@@ -21,7 +21,7 @@ def get_element_list(train_structures: list[Structure | Molecule]) -> tuple[str]
Returns:
Tuple of elements covered in training set
"""
elements = set()
elements: set[str] = set()
for s in train_structures:
elements.update(s.composition.get_el_amt_dict().keys())
return tuple(sorted(elements, key=lambda el: Element(el).Z)) # type: ignore
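The elements: set[str] annotation above is the same idea applied to an empty container: set() alone gives mypy no element type, so the later update() calls get flagged. A short sketch, assuming a hard-coded composition purely for illustration:

    from pymatgen.core import Element

    elements: set[str] = set()  # explicit annotation; set() alone has no inferable element type
    elements.update({"Fe", "O"})

    # get_element_list sorts the collected symbols by atomic number.
    print(tuple(sorted(elements, key=lambda el: Element(el).Z)))  # ('O', 'Fe')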
8 changes: 4 additions & 4 deletions matgl/graph/compute.py
@@ -27,7 +27,7 @@ def compute_3body(g: dgl.DGLGraph):
n_triple_i = n_bond_per_atom * (n_bond_per_atom - 1)
n_triple = torch.sum(n_triple_i)
n_triple_ij = (n_bond_per_atom - 1).repeat_interleave(n_bond_per_atom)
triple_bond_indices = torch.empty((n_triple, 2), dtype=torch.int64)
triple_bond_indices = torch.empty((n_triple, 2), dtype=torch.int64) # type: ignore

start = 0
cs = 0
@@ -64,7 +64,7 @@ def compute_3body(g: dgl.DGLGraph):
l_g.ndata["bond_vec"] = g.edata["bond_vec"]
l_g.ndata["pbc_offset"] = g.edata["pbc_offset"]
l_g.ndata["n_triple_ij"] = n_triple_ij
n_triple_s = torch.tensor(n_triple_s, dtype=torch.int64)
n_triple_s = torch.tensor(n_triple_s, dtype=torch.int64) # type: ignore
return l_g, triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s


@@ -97,8 +97,8 @@ def compute_theta_and_phi(edges: dgl.udf.EdgeBatch):
edges: DGL graph edges
Returns:
cos_theta: torch.tensor
phi: torch.tensor
cos_theta: torch.Tensor
phi: torch.Tensor
triple_bond_lengths (torch.tensor):
"""
vec1 = edges.src["bond_vec"]
13 changes: 8 additions & 5 deletions matgl/graph/data.py
@@ -294,15 +294,18 @@ def load(
filename_line_graph: str = "dgl_line_graph.bin",
filename_state_attr: str = "state_attr.pt",
):
"""Load dgl graphs
"""
Load dgl graphs from files.
Args:
:filename: Name of file storing dgl graphs
:filename: Name of file storing state attrs.
filename: Name of file storing dgl graphs
filename_line_graph: Name of file storing dgl line graphs
filename_state_attr: Name of file storing state attrs.
"""
self.graphs = load_graphs(filename)
self.line_graphs = load_graphs(filename_line_graph)
with open("labels.json") as file:
labels = json.load(file)
labels: dict = json.load(file)
self.energies = labels["energies"]
self.forces = labels["forces"]
self.stresses = labels["stresses"]
@@ -316,7 +319,7 @@ def __getitem__(self, idx: int):
self.state_attr[idx],
self.energies[idx],
torch.tensor(self.forces[idx]),
torch.tensor(self.stresses[idx]),
torch.tensor(self.stresses[idx]), # type: ignore
)

def __len__(self):
4 changes: 2 additions & 2 deletions matgl/layers/_activations.py
@@ -18,7 +18,7 @@ def __init__(self) -> None:
super().__init__()
self.ssp = nn.Softplus()

def forward(self, x: torch.tensor) -> torch.tensor:
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Evaluate activation function given the input tensor x.
Args:
@@ -56,7 +56,7 @@ def __init__(self, alpha: float = None):

self.alpha.requires_grad_(True)

def forward(self, x: torch.tensor) -> torch.tensor:
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Evaluate activation function given the input tensor x.
Args:
2 changes: 1 addition & 1 deletion matgl/layers/_atom_ref.py
@@ -49,7 +49,7 @@ def fit(self, graphs: list, properties: np.typing.NDArray) -> None:
self.property_offset = np.linalg.pinv(features.T.dot(features)).dot(features.T.dot(properties))
self.property_offset = torch.tensor(self.property_offset)

def forward(self, g: dgl.DGLGraph, state_attr: torch.tensor | None = None):
def forward(self, g: dgl.DGLGraph, state_attr: torch.Tensor | None = None):
"""Get the total property offset for a system.
Args:
12 changes: 6 additions & 6 deletions matgl/layers/_basis.py
@@ -94,7 +94,7 @@ def _calculate_smooth_symbolic_funcs(self) -> list:

def __call__(self, r):
"""Args:
r: torch.tensor, distance tensor, 1D.
r: torch.Tensor, distance tensor, 1D.
Returns: [n, max_n * max_l] spherical Bessel function results
@@ -128,7 +128,7 @@ def rbf_j0(r, cutoff: float = 5.0, max_n: int = 3):
vanishes at cutoff.
Args:
r: torch.tensor pytorch tensors
r: torch.Tensor pytorch tensors
cutoff: float, the cutoff radius
max_n: int max number of basis
Returns: basis function expansion using first spherical Bessel function
@@ -166,7 +166,7 @@ def __init__(self, max_l: int, use_phi: bool = True):
def __call__(self, costheta, phi=None):
"""Args:
costheta: Cosine of the azimuthal angle
phi: torch.tensor, the polar angle.
phi: torch.Tensor, the polar angle.
Returns: [n, m] spherical harmonic results, where n is the number
of angles. The column is arranged following
@@ -186,8 +186,8 @@ def _y00(theta, phi):
Y_0^0 = \frac{1}{2} \sqrt{\frac{1}{\pi}}
Args:
theta: torch.tensor, the azimuthal angle
phi: torch.tensor, the polar angle
theta: torch.Tensor, the azimuthal angle
phi: torch.Tensor, the polar angle
Returns: `Y_0^0` results
@@ -209,7 +209,7 @@ def spherical_bessel_smooth(r, cutoff: float = 5.0, max_n: int = 10):
https://arxiv.org/pdf/1907.02374.pdf
Args:
r: torch.tensor distance tensor
r: torch.Tensor distance tensor
cutoff: float, cutoff radius
max_n: int, max number of basis, expanded by the zero roots
2 changes: 1 addition & 1 deletion matgl/layers/_bond.py
@@ -52,7 +52,7 @@ def __init__(
else:
raise Exception("undefined rbf_type, please use SphericalBessel or Gaussian instead.")

def forward(self, bond_dist: torch.tensor):
def forward(self, bond_dist: torch.Tensor):
"""Forward.
Args:
4 changes: 2 additions & 2 deletions matgl/layers/_core.py
@@ -120,7 +120,7 @@ def __init__(self, in_feats: int, dims: list[int], activate_last: bool = True, u
self.gates.append(nn.Linear(in_dim, out_dim, bias=use_bias))
self.gates.append(nn.Sigmoid())

def forward(self, inputs: torch.tensor):
def forward(self, inputs: torch.Tensor):
return self.layers(inputs) * self.gates(inputs)


@@ -144,7 +144,7 @@ def reset_parameters(self):
"""Reinitialize learnable parameters."""
self.lstm.reset_parameters()

def forward(self, g: DGLGraph, feat: torch.tensor):
def forward(self, g: DGLGraph, feat: torch.Tensor):
"""Defines the computation performed at every call.
:param g: Input graph
6 changes: 3 additions & 3 deletions matgl/layers/_graph_convolution.py
@@ -158,9 +158,9 @@ def __init__(
"activate_last": True,
"bias_last": True,
}
self.edge_func = MLP(**mlp_kwargs) if self.has_dense else Identity()
self.node_func = MLP(**mlp_kwargs) if self.has_dense else Identity()
self.state_func = MLP(**mlp_kwargs) if self.has_dense else Identity()
self.edge_func = MLP(**mlp_kwargs) if self.has_dense else Identity() # type: ignore
self.node_func = MLP(**mlp_kwargs) if self.has_dense else Identity() # type: ignore
self.state_func = MLP(**mlp_kwargs) if self.has_dense else Identity() # type: ignore

# compute input sizes
edge_in = 2 * conv_dim + conv_dim + conv_dim # 2*NDIM+EDIM+GDIM
2 changes: 1 addition & 1 deletion matgl/layers/_readout.py
@@ -87,7 +87,7 @@ def forward(self, g: dgl.DGLGraph):
"""Args:
g: DGL graph
Returns:
atomic_properties: torch.tensor.
atomic_properties: torch.Tensor.
"""
atomic_properties = self.gated(g.ndata["node_feat"])
return atomic_properties
14 changes: 7 additions & 7 deletions matgl/layers/_three_body.py
@@ -33,10 +33,10 @@ def forward(
self,
graph: dgl.DGLGraph,
line_graph: dgl.DGLGraph,
three_basis: torch.tensor,
three_basis: torch.Tensor,
three_cutoff: float,
node_feat: torch.tensor,
edge_feat: torch.tensor,
node_feat: torch.Tensor,
edge_feat: torch.Tensor,
):
"""
Forward function for ThreeBodyInteractions.
@@ -54,11 +54,11 @@ def forward(
end_atom_index = torch.unsqueeze(end_atom_index, 1)
atoms = torch.squeeze(atoms[end_atom_index])
basis = three_basis * atoms
three_cutoff = torch.unsqueeze(three_cutoff, dim=1)
three_cutoff = torch.unsqueeze(three_cutoff, dim=1) # type: ignore
weights = torch.reshape(
three_cutoff[torch.stack(list(line_graph.edges()), dim=1).to(torch.int64)], (-1, 2) # type: ignore
)
weights = torch.prod(weights, axis=-1)
weights = torch.prod(weights, axis=-1) # type: ignore
basis = basis * weights[:, None]
new_bonds = scatter_sum(
basis.to(torch.float32),
@@ -83,8 +83,8 @@ def combine_sbf_shf(sbf, shf, max_n: int, max_l: int, use_phi: bool):
[m=[0], m=[0], ...] max_l columns
Args:
sbf: torch.tensor spherical bessel function results
shf: torch.tensor spherical harmonics function results
sbf: torch.Tensor spherical bessel function results
shf: torch.Tensor spherical harmonics function results
max_n: int, max number of n
max_l: int, max number of l
use_phi: whether to use phi
12 changes: 7 additions & 5 deletions matgl/models/_m3gnet.py
@@ -180,7 +180,7 @@ def __init__(
self.readout = Set2SetReadOut(num_steps=niters_set2set, num_layers=nlayers_set2set, field=field)
readout_feats = 2 * input_feats + dim_state_feats if include_state else 2 * input_feats # type: ignore
else:
self.readout = ReduceReadOut("mean", field=field)
self.readout = ReduceReadOut("mean", field=field) # type: ignore
readout_feats = input_feats + dim_state_feats if include_state else input_feats # type: ignore

dims_final_layer = [readout_feats, units, units, ntargets]
@@ -191,7 +191,9 @@ def __init__(
else:
if task_type == "classification":
raise ValueError("Classification task cannot be extensive")
self.final_layer = WeightedReadOut(in_feats=dim_node_embedding, dims=[units, units], num_targets=ntargets)
self.final_layer = WeightedReadOut(
in_feats=dim_node_embedding, dims=[units, units], num_targets=ntargets # type: ignore
)

self.max_n = max_n
self.max_l = max_l
@@ -203,7 +205,7 @@ def __init__(
self.task_type = task_type
self.is_intensive = is_intensive

def forward(self, g: dgl.DGLGraph, state_attr: torch.tensor | None = None, l_g: dgl.DGLGraph | None = None):
def forward(self, g: dgl.DGLGraph, state_attr: torch.Tensor | None = None, l_g: dgl.DGLGraph | None = None):
"""Performs message passing and updates node representations.
Args:
@@ -243,7 +245,7 @@ def forward(self, g: dgl.DGLGraph, state_attr: torch.tensor | None = None, l_g:
g.edata["edge_feat"] = num_edge_feats
if self.is_intensive:
node_vec = self.readout(g)
vec = torch.hstack([node_vec, state_attr]) if self.include_states else node_vec
vec = torch.hstack([node_vec, state_attr]) if self.include_states else node_vec # type: ignore
output = self.final_layer(vec)
if self.task_type == "classification":
output = self.sigmoid(output)
@@ -253,7 +255,7 @@ def forward(self, g: dgl.DGLGraph, state_attr: torch.tensor | None = None, l_g:
return torch.squeeze(output)

def predict_structure(
self, structure, state_feats: torch.tensor | None = None, graph_converter: GraphConverter | None = None
self, structure, state_feats: torch.Tensor | None = None, graph_converter: GraphConverter | None = None
):
"""Convenience method to directly predict property from structure.
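Several of the # type: ignore comments added in this file (for example on the ReduceReadOut assignment) cover a mypy limitation with branch-dependent attributes: the attribute's type is inferred from its first assignment, so assigning a different module class in the other branch is reported as an incompatible assignment. A generic sketch of the situation, using standard torch layers rather than the matgl readout classes:

    import torch.nn as nn

    class Model(nn.Module):
        def __init__(self, is_intensive: bool):
            super().__init__()
            if is_intensive:
                # mypy infers the attribute type from this first assignment ...
                self.readout = nn.AvgPool1d(2)
            else:
                # ... so this branch needs "# type: ignore", unless the attribute
                # is annotated up front with a common base such as nn.Module.
                self.readout = nn.MaxPool1d(2)  # type: ignore

An alternative to the ignore comment is an explicit annotation (self.readout: nn.Module) before the branch, which keeps the checker aware of the attribute's actual type.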
2 changes: 1 addition & 1 deletion matgl/models/_megnet.py
@@ -203,7 +203,7 @@ def forward(
def predict_structure(
self,
structure,
state_feats: torch.tensor | None = None,
state_feats: torch.Tensor | None = None,
graph_converter: GraphConverter | None = None,
):
"""Convenience method to directly predict property from structure.