Skip to content

Commit

Permalink
refactor test_dataset_no_shuffling
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Apr 23, 2024
1 parent e2b6ef2 commit 11315e5
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 36 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.7
rev: v0.4.1
hooks:
- id: ruff
args: [--fix]
Expand Down Expand Up @@ -46,7 +46,7 @@ repos:
- svelte

- repo: https://github.com/pre-commit/mirrors-eslint
rev: v9.0.0
rev: v9.1.1
hooks:
- id: eslint
types: [file]
Expand Down
6 changes: 3 additions & 3 deletions chgnet/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self, index: int, info: dict | None = None) -> None:
self.info = info
self.neighbors: dict[int, list[DirectedEdge | UndirectedEdge]] = {}

def add_neighbor(self, index, edge):
def add_neighbor(self, index, edge) -> None:
"""Draw a directed edge between self and the node specified by index.
Args:
Expand All @@ -44,7 +44,7 @@ def __init__(
self.index = index
self.info = info

def __repr__(self):
def __repr__(self) -> str:
"""String representation of this edge."""
nodes, index, info = self.nodes, self.index, self.info
return f"{type(self).__name__}({nodes=}, {index=}, {info=})"
Expand Down Expand Up @@ -336,7 +336,7 @@ def as_dict(self):
"undirected_edges_list": self.undirected_edges_list,
}

def to(self, filename="graph.json"):
def to(self, filename="graph.json") -> None:
"""Save graph dictionary to file."""
write_json(self.as_dict(), filename)

Expand Down
9 changes: 5 additions & 4 deletions chgnet/model/composition_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

if TYPE_CHECKING:
from collections.abc import Sequence
from pathlib import Path

from chgnet.graph.crystalgraph import CrystalGraph

Expand Down Expand Up @@ -199,7 +200,7 @@ def get_site_energies(self, graphs: list[CrystalGraph]):
for graph in graphs
]

def initialize_from(self, dataset: str):
def initialize_from(self, dataset: str) -> None:
"""Initialize pre-fitted weights from a dataset."""
if dataset in ["MPtrj", "MPtrj_e"]:
self.initialize_from_MPtrj()
Expand All @@ -208,7 +209,7 @@ def initialize_from(self, dataset: str):
else:
raise NotImplementedError(f"{dataset=} not supported yet")

def initialize_from_MPtrj(self):
def initialize_from_MPtrj(self) -> None:
"""Initialize pre-fitted weights from MPtrj dataset."""
state_dict = collections.OrderedDict()
state_dict["weight"] = torch.tensor(
Expand Down Expand Up @@ -313,7 +314,7 @@ def initialize_from_MPtrj(self):
self.is_intensive = True
self.fitted = True

def initialize_from_MPF(self):
def initialize_from_MPF(self) -> None:
"""Initialize pre-fitted weights from MPF dataset."""
state_dict = collections.OrderedDict()
state_dict["weight"] = torch.tensor(
Expand Down Expand Up @@ -418,7 +419,7 @@ def initialize_from_MPF(self):
self.is_intensive = False
self.fitted = True

def initialize_from_numpy(self, file_name):
def initialize_from_numpy(self, file_name: str | Path) -> None:
"""Initialize pre-fitted weights from numpy file."""
atom_ref_np = np.load(file_name)
state_dict = collections.OrderedDict()
Expand Down
2 changes: 1 addition & 1 deletion examples/crystaltoolkit_relax_viewer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@
" return structure, fig\n",
"\n",
"\n",
"app.run(height=800, use_reloader=False)\n"
"app.run(height=800, use_reloader=False)"
]
}
],
Expand Down
37 changes: 11 additions & 26 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,33 +92,18 @@ def test_structure_data_inconsistent_length():


def test_dataset_no_shuffling():
structures, energies, forces, stresses, magmoms, structure_ids = (
[],
[],
[],
[],
[],
[],
)
for index in range(100):
struct = NaCl.copy()
struct.perturb(0.1)
structures.append(struct)
energies.append(np.random.random(1))
forces.append(np.random.random([2, 3]))
stresses.append(np.random.random([3, 3]))
magmoms.append(np.random.random([2, 1]))
structure_ids.append(index)
n_samples = 100
structure_ids = list(range(n_samples))

structure_data = StructureData(
structures=structures,
energies=energies,
forces=forces,
stresses=stresses,
magmoms=magmoms,
structures=[NaCl.copy().perturb(0.1) for _ in range(n_samples)],
energies=np.random.random(n_samples),
forces=np.random.random([n_samples, 2, 3]),
stresses=np.random.random([n_samples, 3, 3]),
magmoms=np.random.random([n_samples, 2, 1]),
structure_ids=structure_ids,
shuffle=False,
)

assert structure_data[0][0].mp_id == 0
assert structure_data[1][0].mp_id == 1
assert structure_data[2][0].mp_id == 2
sample_ids = [data[0].mp_id for data in structure_data]
# shuffle=False means structure_ids should be in order
assert sample_ids == structure_ids

0 comments on commit 11315e5

Please sign in to comment.