Skip to content

Commit

Permalink
Support NumPy 2 (#202)
Browse files Browse the repository at this point in the history
Fixed data type compatibility for support of Numpy 2

---------

Co-authored-by: BowenD-UCB <[email protected]>
  • Loading branch information
DanielYang59 and BowenD-UCB authored Sep 12, 2024
1 parent 9281cf4 commit 035abf2
Show file tree
Hide file tree
Showing 13 changed files with 140 additions and 121 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ jobs:
pip install uv
uv pip install -e .[test,logging] --resolution=${{ matrix.version.resolution }} --system
# TODO: remove pin once reverse readline fixed
uv pip install monty==2024.7.12 --system
- name: Run Tests
run: pytest --capture=no --cov --cov-report=xml
env:
Expand Down
40 changes: 22 additions & 18 deletions chgnet/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from chgnet import TrainTask

warnings.filterwarnings("ignore")
datatype = torch.float32
TORCH_DTYPE = torch.float32


class StructureData(Dataset):
Expand Down Expand Up @@ -163,21 +163,21 @@ def __getitem__(self, idx: int) -> tuple[CrystalGraph, dict]:
struct, graph_id=graph_id, mp_id=mp_id
)
targets = {
"e": torch.tensor(self.energies[graph_id], dtype=datatype),
"f": torch.tensor(self.forces[graph_id], dtype=datatype),
"e": torch.tensor(self.energies[graph_id], dtype=TORCH_DTYPE),
"f": torch.tensor(self.forces[graph_id], dtype=TORCH_DTYPE),
}
if self.stresses is not None:
# Convert VASP stress
targets["s"] = torch.tensor(
self.stresses[graph_id], dtype=datatype
self.stresses[graph_id], dtype=TORCH_DTYPE
) * (-0.1)
if self.magmoms is not None:
mag = self.magmoms[graph_id]
# use absolute value for magnetic moments
if mag is None:
targets["m"] = None
else:
targets["m"] = torch.abs(torch.tensor(mag, dtype=datatype))
targets["m"] = torch.abs(torch.tensor(mag, dtype=TORCH_DTYPE))

return crystal_graph, targets

Expand Down Expand Up @@ -275,18 +275,18 @@ def __getitem__(self, idx: int) -> tuple[CrystalGraph, dict[str, Tensor]]:
for key in self.targets:
if key == "e":
energy = self.data[graph_id][self.energy_key]
targets["e"] = torch.tensor(energy, dtype=datatype)
targets["e"] = torch.tensor(energy, dtype=TORCH_DTYPE)
elif key == "f":
force = self.data[graph_id][self.force_key]
targets["f"] = torch.tensor(force, dtype=datatype)
targets["f"] = torch.tensor(force, dtype=TORCH_DTYPE)
elif key == "s":
stress = self.data[graph_id][self.stress_key]
# Convert VASP stress
targets["s"] = torch.tensor(stress, dtype=datatype) * -0.1
targets["s"] = torch.tensor(stress, dtype=TORCH_DTYPE) * -0.1
elif key == "m":
mag = self.data[graph_id][self.magmom_key]
# use absolute value for magnetic moments
targets["m"] = torch.abs(torch.tensor(mag, dtype=datatype))
targets["m"] = torch.abs(torch.tensor(mag, dtype=TORCH_DTYPE))
return crystal_graph, targets

# Omit structures with isolated atoms.
Expand Down Expand Up @@ -404,21 +404,23 @@ def __getitem__(self, idx) -> tuple[CrystalGraph, dict[str, Tensor]]:
for key in self.targets:
if key == "e":
energy = self.labels[mp_id][graph_id][self.energy_key]
targets["e"] = torch.tensor(energy, dtype=datatype)
targets["e"] = torch.tensor(energy, dtype=TORCH_DTYPE)
elif key == "f":
force = self.labels[mp_id][graph_id][self.force_key]
targets["f"] = torch.tensor(force, dtype=datatype)
targets["f"] = torch.tensor(force, dtype=TORCH_DTYPE)
elif key == "s":
stress = self.labels[mp_id][graph_id][self.stress_key]
# Convert VASP stress
targets["s"] = torch.tensor(stress, dtype=datatype) * (-0.1)
targets["s"] = torch.tensor(stress, dtype=TORCH_DTYPE) * (-0.1)
elif key == "m":
mag = self.labels[mp_id][graph_id][self.magmom_key]
# use absolute value for magnetic moments
if mag is None:
targets["m"] = None
else:
targets["m"] = torch.abs(torch.tensor(mag, dtype=datatype))
targets["m"] = torch.abs(
torch.tensor(mag, dtype=TORCH_DTYPE)
)
return crystal_graph, targets

# Omit failed structures. Return another randomly selected structure
Expand Down Expand Up @@ -629,21 +631,23 @@ def __getitem__(self, idx: int) -> tuple[CrystalGraph, dict[str, Tensor]]:
for key in self.targets:
if key == "e":
energy = self.data[mp_id][graph_id][self.energy_key]
targets["e"] = torch.tensor(energy, dtype=datatype)
targets["e"] = torch.tensor(energy, dtype=TORCH_DTYPE)
elif key == "f":
force = self.data[mp_id][graph_id][self.force_key]
targets["f"] = torch.tensor(force, dtype=datatype)
targets["f"] = torch.tensor(force, dtype=TORCH_DTYPE)
elif key == "s":
stress = self.data[mp_id][graph_id][self.stress_key]
# Convert VASP stress
targets["s"] = torch.tensor(stress, dtype=datatype) * (-0.1)
targets["s"] = torch.tensor(stress, dtype=TORCH_DTYPE) * (-0.1)
elif key == "m":
mag = self.data[mp_id][graph_id][self.magmom_key]
# use absolute value for magnetic moments
if mag is None:
targets["m"] = None
else:
targets["m"] = torch.abs(torch.tensor(mag, dtype=datatype))
targets["m"] = torch.abs(
torch.tensor(mag, dtype=TORCH_DTYPE)
)
return crystal_graph, targets

# Omit structures with isolated atoms. Return another randomly selected
Expand Down Expand Up @@ -773,7 +777,7 @@ def collate_graphs(batch_data: list) -> tuple[list[CrystalGraph], dict[str, Tens
graphs = [graph for graph, _ in batch_data]
all_targets = {key: [] for key in batch_data[0][1]}
all_targets["e"] = torch.tensor(
[targets["e"] for _, targets in batch_data], dtype=datatype
[targets["e"] for _, targets in batch_data], dtype=TORCH_DTYPE
)

for _, targets in batch_data:
Expand Down
10 changes: 5 additions & 5 deletions chgnet/graph/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
except (ImportError, AttributeError):
make_graph = None

DATATYPE = torch.float32
TORCH_DTYPE = torch.float32


class CrystalGraphConverter(nn.Module):
Expand Down Expand Up @@ -124,10 +124,10 @@ def forward(
requires_grad=False,
)
atom_frac_coord = torch.tensor(
structure.frac_coords, dtype=DATATYPE, requires_grad=True
structure.frac_coords, dtype=TORCH_DTYPE, requires_grad=True
)
lattice = torch.tensor(
structure.lattice.matrix, dtype=DATATYPE, requires_grad=True
structure.lattice.matrix, dtype=TORCH_DTYPE, requires_grad=True
)
center_index, neighbor_index, image, distance = structure.get_neighbor_list(
r=self.atom_graph_cutoff, sites=structure.sites, numerical_tol=1e-8
Expand Down Expand Up @@ -177,7 +177,7 @@ def forward(
atomic_number=atomic_number,
atom_frac_coord=atom_frac_coord,
atom_graph=atom_graph,
neighbor_image=torch.tensor(image, dtype=DATATYPE),
neighbor_image=torch.tensor(image, dtype=TORCH_DTYPE),
directed2undirected=directed2undirected,
undirected2directed=undirected2directed,
bond_graph=bond_graph,
Expand Down Expand Up @@ -250,7 +250,7 @@ def _create_graph_fast(
"""
center_index = np.ascontiguousarray(center_index)
neighbor_index = np.ascontiguousarray(neighbor_index)
image = np.ascontiguousarray(image, dtype=np.int_)
image = np.ascontiguousarray(image, dtype=np.int64)
distance = np.ascontiguousarray(distance)
gc_saved = gc.get_threshold()
gc.set_threshold(0)
Expand Down
2 changes: 1 addition & 1 deletion chgnet/graph/crystalgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
if TYPE_CHECKING:
from typing_extensions import Self

datatype = torch.float32
TORCH_DTYPE = torch.float32


class CrystalGraph:
Expand Down
67 changes: 36 additions & 31 deletions chgnet/graph/cygraph.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -7,70 +7,75 @@
# cython: profile=False
# distutils: language = c

import chgnet.graph.graph

import numpy as np
cimport numpy as np

import chgnet.graph.graph

from libc.stdlib cimport free


cdef extern from 'fast_converter_libraries/create_graph.c':
ctypedef struct Node:
long index
np.int64_t index
LongToDirectedEdgeList* neighbors
long num_neighbors
np.int64_t num_neighbors

ctypedef struct NodeIndexPair:
long center
long neighbor
np.int64_t center
np.int64_t neighbor

ctypedef struct UndirectedEdge:
NodeIndexPair nodes
long index
long* directed_edge_indices
long num_directed_edges
double distance
np.int64_t index
np.int64_t* directed_edge_indices
np.int64_t num_directed_edges
np.float64_t distance

ctypedef struct DirectedEdge:
NodeIndexPair nodes
long index
const long* image
long undirected_edge_index
double distance
np.int64_t index
const np.int64_t* image
np.int64_t undirected_edge_index
np.float64_t distance

ctypedef struct LongToDirectedEdgeList:
long key
np.int64_t key
DirectedEdge** directed_edges_list
int num_directed_edges_in_group

ctypedef struct ReturnElems2:
long num_nodes
long num_directed_edges
long num_undirected_edges
np.int64_t num_nodes
np.int64_t num_directed_edges
np.int64_t num_undirected_edges
Node* nodes
UndirectedEdge** undirected_edges_list
DirectedEdge** directed_edges_list

ReturnElems2* create_graph(
long* center_index,
long n_e,
long* neighbor_index,
long* image,
double* distance,
long num_atoms)
np.int64_t* center_index,
np.int64_t n_e,
np.int64_t* neighbor_index,
np.int64_t* image,
np.float64_t* distance,
np.int64_t num_atoms)

void free_LongToDirectedEdgeList_in_nodes(Node* nodes, long num_nodes)
void free_LongToDirectedEdgeList_in_nodes(Node* nodes, np.int64_t num_nodes)


LongToDirectedEdgeList** get_neighbors(Node* node)

def make_graph(
const long[::1] center_index,
const long n_e,
const long[::1] neighbor_index,
const long[:, ::1] image,
const double[::1] distance,
const long num_atoms
const np.int64_t[::1] center_index,
const np.int64_t n_e,
const np.int64_t[::1] neighbor_index,
const np.int64_t[:, ::1] image,
const np.float64_t[::1] distance,
const np.int64_t num_atoms
):
cdef ReturnElems2* returned
returned = <ReturnElems2*> create_graph(<long*> &center_index[0], n_e, <long*> &neighbor_index[0], <long*> &image[0][0], <double*> &distance[0], num_atoms)
returned = <ReturnElems2*> create_graph(<np.int64_t*> &center_index[0], n_e, <np.int64_t*> &neighbor_index[0], <np.int64_t*> &image[0][0], <np.float64_t*> &distance[0], num_atoms)

chg_DirectedEdge = chgnet.graph.graph.DirectedEdge
chg_Node = chgnet.graph.graph.Node
Expand Down
Loading

0 comments on commit 035abf2

Please sign in to comment.