Skip to content

Commit

Permalink
datatype to TORCH_DTYPE for clarity
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielYang59 committed Sep 12, 2024
1 parent 780b7a7 commit 5a2c7be
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 9 deletions.
8 changes: 4 additions & 4 deletions chgnet/graph/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
except (ImportError, AttributeError):
make_graph = None

DATATYPE = torch.float32
TORCH_DTYPE = torch.float32


class CrystalGraphConverter(nn.Module):
Expand Down Expand Up @@ -124,10 +124,10 @@ def forward(
requires_grad=False,
)
atom_frac_coord = torch.tensor(
structure.frac_coords, dtype=DATATYPE, requires_grad=True
structure.frac_coords, dtype=TORCH_DTYPE, requires_grad=True
)
lattice = torch.tensor(
structure.lattice.matrix, dtype=DATATYPE, requires_grad=True
structure.lattice.matrix, dtype=TORCH_DTYPE, requires_grad=True
)
center_index, neighbor_index, image, distance = structure.get_neighbor_list(
r=self.atom_graph_cutoff, sites=structure.sites, numerical_tol=1e-8
Expand Down Expand Up @@ -177,7 +177,7 @@ def forward(
atomic_number=atomic_number,
atom_frac_coord=atom_frac_coord,
atom_graph=atom_graph,
neighbor_image=torch.tensor(image, dtype=DATATYPE),
neighbor_image=torch.tensor(image, dtype=TORCH_DTYPE),
directed2undirected=directed2undirected,
undirected2directed=undirected2directed,
bond_graph=bond_graph,
Expand Down
2 changes: 1 addition & 1 deletion chgnet/graph/crystalgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
if TYPE_CHECKING:
from typing_extensions import Self

datatype = torch.float32
TORCH_DTYPE = torch.float32


class CrystalGraph:
Expand Down
6 changes: 3 additions & 3 deletions chgnet/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torch import Tensor, nn

from chgnet.graph import CrystalGraph, CrystalGraphConverter
from chgnet.graph.crystalgraph import datatype
from chgnet.graph.crystalgraph import TORCH_DTYPE
from chgnet.model.composition_model import AtomRef
from chgnet.model.encoders import AngleEncoder, AtomEmbedding, BondEncoder
from chgnet.model.functions import MLP, GatedMLP, find_normalization
Expand Down Expand Up @@ -808,7 +808,7 @@ def from_graphs(
if compute_stress:
strain = graph.lattice.new_zeros([3, 3], requires_grad=True)
lattice = graph.lattice @ (
torch.eye(3, dtype=datatype).to(strain.device) + strain
torch.eye(3, dtype=TORCH_DTYPE).to(strain.device) + strain
)
else:
strain = None
Expand Down Expand Up @@ -878,7 +878,7 @@ def from_graphs(
torch.cat(atom_owners, dim=0).type(torch.int32).to(atomic_numbers.device)
)
directed2undirected = torch.cat(directed2undirected, dim=0)
volumes = torch.tensor(volumes, dtype=datatype, device=atomic_numbers.device)
volumes = torch.tensor(volumes, dtype=TORCH_DTYPE, device=atomic_numbers.device)

return cls(
atomic_numbers=atomic_numbers,
Expand Down
1 change: 0 additions & 1 deletion examples/make_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from chgnet.data.dataset import StructureData, StructureJsonData
from chgnet.graph import CrystalGraphConverter

datatype = torch.float32
random.seed(100)


Expand Down

0 comments on commit 5a2c7be

Please sign in to comment.