Skip to content

Commit

Permalink
rename module var datatype to TORCH_DTYPE
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielYang59 committed Sep 12, 2024
1 parent 48b76c6 commit 780b7a7
Showing 1 changed file with 22 additions and 18 deletions.
40 changes: 22 additions & 18 deletions chgnet/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from chgnet import TrainTask

warnings.filterwarnings("ignore")
datatype = torch.float32
TORCH_DTYPE = torch.float32


class StructureData(Dataset):
Expand Down Expand Up @@ -163,21 +163,21 @@ def __getitem__(self, idx: int) -> tuple[CrystalGraph, dict]:
struct, graph_id=graph_id, mp_id=mp_id
)
targets = {
"e": torch.tensor(self.energies[graph_id], dtype=datatype),
"f": torch.tensor(self.forces[graph_id], dtype=datatype),
"e": torch.tensor(self.energies[graph_id], dtype=TORCH_DTYPE),
"f": torch.tensor(self.forces[graph_id], dtype=TORCH_DTYPE),
}
if self.stresses is not None:
# Convert VASP stress
targets["s"] = torch.tensor(
self.stresses[graph_id], dtype=datatype
self.stresses[graph_id], dtype=TORCH_DTYPE
) * (-0.1)
if self.magmoms is not None:
mag = self.magmoms[graph_id]
# use absolute value for magnetic moments
if mag is None:
targets["m"] = None
else:
targets["m"] = torch.abs(torch.tensor(mag, dtype=datatype))
targets["m"] = torch.abs(torch.tensor(mag, dtype=TORCH_DTYPE))

return crystal_graph, targets

Expand Down Expand Up @@ -275,18 +275,18 @@ def __getitem__(self, idx: int) -> tuple[CrystalGraph, dict[str, Tensor]]:
for key in self.targets:
if key == "e":
energy = self.data[graph_id][self.energy_key]
targets["e"] = torch.tensor(energy, dtype=datatype)
targets["e"] = torch.tensor(energy, dtype=TORCH_DTYPE)
elif key == "f":
force = self.data[graph_id][self.force_key]
targets["f"] = torch.tensor(force, dtype=datatype)
targets["f"] = torch.tensor(force, dtype=TORCH_DTYPE)
elif key == "s":
stress = self.data[graph_id][self.stress_key]
# Convert VASP stress
targets["s"] = torch.tensor(stress, dtype=datatype) * -0.1
targets["s"] = torch.tensor(stress, dtype=TORCH_DTYPE) * -0.1
elif key == "m":
mag = self.data[graph_id][self.magmom_key]
# use absolute value for magnetic moments
targets["m"] = torch.abs(torch.tensor(mag, dtype=datatype))
targets["m"] = torch.abs(torch.tensor(mag, dtype=TORCH_DTYPE))
return crystal_graph, targets

# Omit structures with isolated atoms.
Expand Down Expand Up @@ -404,21 +404,23 @@ def __getitem__(self, idx) -> tuple[CrystalGraph, dict[str, Tensor]]:
for key in self.targets:
if key == "e":
energy = self.labels[mp_id][graph_id][self.energy_key]
targets["e"] = torch.tensor(energy, dtype=datatype)
targets["e"] = torch.tensor(energy, dtype=TORCH_DTYPE)
elif key == "f":
force = self.labels[mp_id][graph_id][self.force_key]
targets["f"] = torch.tensor(force, dtype=datatype)
targets["f"] = torch.tensor(force, dtype=TORCH_DTYPE)
elif key == "s":
stress = self.labels[mp_id][graph_id][self.stress_key]
# Convert VASP stress
targets["s"] = torch.tensor(stress, dtype=datatype) * (-0.1)
targets["s"] = torch.tensor(stress, dtype=TORCH_DTYPE) * (-0.1)
elif key == "m":
mag = self.labels[mp_id][graph_id][self.magmom_key]
# use absolute value for magnetic moments
if mag is None:
targets["m"] = None
else:
targets["m"] = torch.abs(torch.tensor(mag, dtype=datatype))
targets["m"] = torch.abs(
torch.tensor(mag, dtype=TORCH_DTYPE)
)
return crystal_graph, targets

# Omit failed structures. Return another randomly selected structure
Expand Down Expand Up @@ -629,21 +631,23 @@ def __getitem__(self, idx: int) -> tuple[CrystalGraph, dict[str, Tensor]]:
for key in self.targets:
if key == "e":
energy = self.data[mp_id][graph_id][self.energy_key]
targets["e"] = torch.tensor(energy, dtype=datatype)
targets["e"] = torch.tensor(energy, dtype=TORCH_DTYPE)
elif key == "f":
force = self.data[mp_id][graph_id][self.force_key]
targets["f"] = torch.tensor(force, dtype=datatype)
targets["f"] = torch.tensor(force, dtype=TORCH_DTYPE)
elif key == "s":
stress = self.data[mp_id][graph_id][self.stress_key]
# Convert VASP stress
targets["s"] = torch.tensor(stress, dtype=datatype) * (-0.1)
targets["s"] = torch.tensor(stress, dtype=TORCH_DTYPE) * (-0.1)
elif key == "m":
mag = self.data[mp_id][graph_id][self.magmom_key]
# use absolute value for magnetic moments
if mag is None:
targets["m"] = None
else:
targets["m"] = torch.abs(torch.tensor(mag, dtype=datatype))
targets["m"] = torch.abs(
torch.tensor(mag, dtype=TORCH_DTYPE)
)
return crystal_graph, targets

# Omit structures with isolated atoms. Return another randomly selected
Expand Down Expand Up @@ -773,7 +777,7 @@ def collate_graphs(batch_data: list) -> tuple[list[CrystalGraph], dict[str, Tens
graphs = [graph for graph, _ in batch_data]
all_targets = {key: [] for key in batch_data[0][1]}
all_targets["e"] = torch.tensor(
[targets["e"] for _, targets in batch_data], dtype=datatype
[targets["e"] for _, targets in batch_data], dtype=TORCH_DTYPE
)

for _, targets in batch_data:
Expand Down

0 comments on commit 780b7a7

Please sign in to comment.