diff --git a/chgnet/graph/converter.py b/chgnet/graph/converter.py index 4b26db59..3fb98f15 100644 --- a/chgnet/graph/converter.py +++ b/chgnet/graph/converter.py @@ -23,7 +23,7 @@ except (ImportError, AttributeError): make_graph = None -DATATYPE = torch.float32 +TORCH_DTYPE = torch.float32 class CrystalGraphConverter(nn.Module): @@ -124,10 +124,10 @@ def forward( requires_grad=False, ) atom_frac_coord = torch.tensor( - structure.frac_coords, dtype=DATATYPE, requires_grad=True + structure.frac_coords, dtype=TORCH_DTYPE, requires_grad=True ) lattice = torch.tensor( - structure.lattice.matrix, dtype=DATATYPE, requires_grad=True + structure.lattice.matrix, dtype=TORCH_DTYPE, requires_grad=True ) center_index, neighbor_index, image, distance = structure.get_neighbor_list( r=self.atom_graph_cutoff, sites=structure.sites, numerical_tol=1e-8 @@ -177,7 +177,7 @@ def forward( atomic_number=atomic_number, atom_frac_coord=atom_frac_coord, atom_graph=atom_graph, - neighbor_image=torch.tensor(image, dtype=DATATYPE), + neighbor_image=torch.tensor(image, dtype=TORCH_DTYPE), directed2undirected=directed2undirected, undirected2directed=undirected2directed, bond_graph=bond_graph, diff --git a/chgnet/graph/crystalgraph.py b/chgnet/graph/crystalgraph.py index 637b359a..4f4572d1 100644 --- a/chgnet/graph/crystalgraph.py +++ b/chgnet/graph/crystalgraph.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from typing_extensions import Self -datatype = torch.float32 +TORCH_DTYPE = torch.float32 class CrystalGraph: diff --git a/chgnet/model/model.py b/chgnet/model/model.py index d2030337..abca5b21 100644 --- a/chgnet/model/model.py +++ b/chgnet/model/model.py @@ -11,7 +11,7 @@ from torch import Tensor, nn from chgnet.graph import CrystalGraph, CrystalGraphConverter -from chgnet.graph.crystalgraph import datatype +from chgnet.graph.crystalgraph import TORCH_DTYPE from chgnet.model.composition_model import AtomRef from chgnet.model.encoders import AngleEncoder, AtomEmbedding, BondEncoder from chgnet.model.functions import MLP, GatedMLP, find_normalization @@ -808,7 +808,7 @@ def from_graphs( if compute_stress: strain = graph.lattice.new_zeros([3, 3], requires_grad=True) lattice = graph.lattice @ ( - torch.eye(3, dtype=datatype).to(strain.device) + strain + torch.eye(3, dtype=TORCH_DTYPE).to(strain.device) + strain ) else: strain = None @@ -878,7 +878,7 @@ def from_graphs( torch.cat(atom_owners, dim=0).type(torch.int32).to(atomic_numbers.device) ) directed2undirected = torch.cat(directed2undirected, dim=0) - volumes = torch.tensor(volumes, dtype=datatype, device=atomic_numbers.device) + volumes = torch.tensor(volumes, dtype=TORCH_DTYPE, device=atomic_numbers.device) return cls( atomic_numbers=atomic_numbers, diff --git a/examples/make_graphs.py b/examples/make_graphs.py index 8aacc2a5..15a5fcfe 100644 --- a/examples/make_graphs.py +++ b/examples/make_graphs.py @@ -10,7 +10,6 @@ from chgnet.data.dataset import StructureData, StructureJsonData from chgnet.graph import CrystalGraphConverter -datatype = torch.float32 random.seed(100)