Skip to content

Commit

Permalink
Fix KeyError: 'decay_fraction' and `TypeError: Object of type int64…
Browse files Browse the repository at this point in the history
… is not JSON serializable` (#169)

* fix KeyError: 'decay_fraction'

Loaded model params = 412,525
Traceback (most recent call last):
  File "{path}/trainer_reload.py", line 31, in <module>
    trainer = load_trainer(file_path)
              ^^^^^^^^^^^^^^^^^^^^^^^
  File "{path}/trainer_reload.py", line 25, in load_trainer
    trainer = Trainer.load(trainer_path)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "{path}/lib/python3.11/site-packages/chgnet/trainer/trainer.py", line 545, in load
    trainer = Trainer(model=model, **state["trainer_args"])
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "{path}/lib/python3.11/site-packages/chgnet/trainer/trainer.py", line 147, in __init__
    decay_fraction = scheduler_params.pop("decay_fraction")
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
KeyError: 'decay_fraction'

#168

* fix TypeError: Object of type int64 is not JSON serializable when passing trainer.train(save_test_result="test-preds.json")

* test passing save_test_result in test_trainer
  • Loading branch information
janosh committed Jun 21, 2024
1 parent 9717a32 commit 7c21a94
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:

- name: Install dependencies
run: |
uv pip install cython setuptools --system
uv pip install cython 'setuptools<70' --system
python setup.py build_ext --inplace
Expand Down
12 changes: 8 additions & 4 deletions chgnet/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def __init__(
)

# Define learning rate scheduler
default_decay_frac = 1e-2
if scheduler in {"MultiStepLR", "multistep"}:
scheduler_params = kwargs.pop(
"scheduler_params",
Expand All @@ -164,8 +165,10 @@ def __init__(
self.scheduler = ExponentialLR(self.optimizer, **scheduler_params)
self.scheduler_type = "exp"
elif scheduler in {"CosineAnnealingLR", "CosLR", "Cos", "cos"}:
scheduler_params = kwargs.pop("scheduler_params", {"decay_fraction": 1e-2})
decay_fraction = scheduler_params.pop("decay_fraction")
scheduler_params = kwargs.pop(
"scheduler_params", {"decay_fraction": default_decay_frac}
)
decay_fraction = scheduler_params.pop("decay_fraction", default_decay_frac)
self.scheduler = CosineAnnealingLR(
self.optimizer,
T_max=10 * epochs, # Maximum number of iterations.
Expand All @@ -174,9 +177,10 @@ def __init__(
self.scheduler_type = "cos"
elif scheduler == "CosRestartLR":
scheduler_params = kwargs.pop(
"scheduler_params", {"decay_fraction": 1e-2, "T_0": 10, "T_mult": 2}
"scheduler_params",
{"decay_fraction": default_decay_frac, "T_0": 10, "T_mult": 2},
)
decay_fraction = scheduler_params.pop("decay_fraction")
decay_fraction = scheduler_params.pop("decay_fraction", default_decay_frac)
self.scheduler = CosineAnnealingWarmRestarts(
self.optimizer,
eta_min=decay_fraction * learning_rate,
Expand Down
14 changes: 13 additions & 1 deletion chgnet/utils/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import os

import numpy as np
import nvidia_smi
import torch
from torch import Tensor
Expand Down Expand Up @@ -115,8 +116,19 @@ def write_json(dct: dict, filepath: str) -> dict:
Returns:
written dictionary
"""

def handler(obj: object) -> int | object:
"""Convert numpy int64 to int.
Fixes TypeError: Object of type int64 is not JSON serializable
reported in https://github.com/CederGroupHub/chgnet/issues/168.
"""
if isinstance(obj, np.integer):
return int(obj)
return obj

with open(filepath, "w") as file:
json.dump(dct, file)
json.dump(dct, file, default=handler)


def mkdir(path: str) -> str:
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ license = { text = "Modified BSD" }
dependencies = [
"ase>=3.23.0",
"cython>=3",
"numpy>=1.26",
"numpy>=1.26,<2",
"nvidia-ml-py3>=7.352.0",
"pymatgen>=2023.10.11",
"torch>=1.11.0",
Expand Down Expand Up @@ -46,7 +46,7 @@ find = { include = ["chgnet*"], exclude = ["tests", "tests*"] }
"chgnet.pretrained" = ["*", "**/*"]

[build-system]
requires = ["Cython", "setuptools>=65.0", "wheel"]
requires = ["Cython", "setuptools>=65,<70", "wheel"]
build-backend = "setuptools.build_meta"

[tool.ruff]
Expand Down
19 changes: 10 additions & 9 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,18 @@ def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
wandb_init_kwargs=dict(anonymous="must"),
extra_run_config=extra_run_config,
)
trainer.train(
train_loader,
val_loader,
save_dir=tmp_path,
save_test_result=tmp_path / "test-preds.json",
)
assert dict(wandb.config).items() >= extra_run_config.items()
dir_name = "test_tmp_dir"
test_dir = tmp_path / dir_name
trainer.train(train_loader, val_loader, save_dir=test_dir)
for param in chgnet.composition_model.parameters():
assert param.requires_grad is False
assert test_dir.is_dir(), "Training dir was not created"
assert tmp_path.is_dir(), "Training dir was not created"

output_files = [file.name for file in test_dir.iterdir()]
output_files = [file.name for file in tmp_path.iterdir()]
for prefix in ("epoch", "bestE_", "bestF_"):
n_matches = sum(file.startswith(prefix) for file in output_files)
assert (
Expand Down Expand Up @@ -92,16 +95,14 @@ def test_trainer_composition_model(tmp_path: Path) -> None:
learning_rate=1e-2,
epochs=5,
)
dir_name = "test_tmp_dir2"
test_dir = tmp_path / dir_name
initial_weights = chgnet.composition_model.state_dict()["fc.weight"].clone()
trainer.train(
train_loader, val_loader, save_dir=test_dir, train_composition_model=True
train_loader, val_loader, save_dir=tmp_path, train_composition_model=True
)
for param in chgnet.composition_model.parameters():
assert param.requires_grad is True

output_files = list(test_dir.iterdir())
output_files = list(tmp_path.iterdir())
weights_path = next(file for file in output_files if file.name.startswith("epoch"))
new_chgnet = CHGNet.from_file(weights_path)
for param in new_chgnet.composition_model.parameters():
Expand Down

0 comments on commit 7c21a94

Please sign in to comment.