Skip to content

Commit

Permalink
test passing save_test_result in test_trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Jun 21, 2024
1 parent 9df59b2 commit 9484896
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,17 @@ def test_trainer(tmp_path: Path) -> None:
learning_rate=1e-2,
epochs=5,
)
dir_name = "test_tmp_dir"
test_dir = tmp_path / dir_name
trainer.train(train_loader, val_loader, save_dir=test_dir)
trainer.train(
train_loader,
val_loader,
save_dir=tmp_path,
save_test_result=tmp_path / "test-preds.json",
)
for param in chgnet.composition_model.parameters():
assert param.requires_grad is False
assert test_dir.is_dir(), "Training dir was not created"
assert tmp_path.is_dir(), "Training dir was not created"

output_files = [file.name for file in test_dir.iterdir()]
output_files = [file.name for file in tmp_path.iterdir()]
for prefix in ("epoch", "bestE_", "bestF_"):
n_matches = sum(file.startswith(prefix) for file in output_files)
assert (
Expand All @@ -79,16 +82,14 @@ def test_trainer_composition_model(tmp_path: Path) -> None:
learning_rate=1e-2,
epochs=5,
)
dir_name = "test_tmp_dir2"
test_dir = tmp_path / dir_name
initial_weights = chgnet.composition_model.state_dict()["fc.weight"].clone()
trainer.train(
train_loader, val_loader, save_dir=test_dir, train_composition_model=True
train_loader, val_loader, save_dir=tmp_path, train_composition_model=True
)
for param in chgnet.composition_model.parameters():
assert param.requires_grad is True

output_files = list(test_dir.iterdir())
output_files = list(tmp_path.iterdir())
weights_path = next(file for file in output_files if file.name.startswith("epoch"))
new_chgnet = CHGNet.from_file(weights_path)
for param in new_chgnet.composition_model.parameters():
Expand Down

0 comments on commit 9484896

Please sign in to comment.