Skip to content

Commit

Permalink
fix clashing var names for counters in nested _validate() loops
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Mar 6, 2024
1 parent d426a55 commit 6a5d96a
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 25 deletions.
32 changes: 16 additions & 16 deletions chgnet/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class Trainer:

def __init__(
self,
model: nn.Module | None = None,
model: CHGNet | None = None,
targets: TrainTask = "ef",
energy_loss_ratio: float = 1,
force_loss_ratio: float = 1,
Expand Down Expand Up @@ -382,11 +382,11 @@ def _validate(
test_pred = []

end = time.perf_counter()
for idx, (graphs, targets) in enumerate(val_loader):
for ii, (graphs, targets) in enumerate(val_loader):
if "f" in self.targets or "s" in self.targets:
for g in graphs:
for graph in graphs:
requires_force = "f" in self.targets
g.atom_frac_coord.requires_grad = requires_force
graph.atom_frac_coord.requires_grad = requires_force
graphs = [g.to(self.device) for g in graphs]
targets = {k: self.move_to(v, self.device) for k, v in targets.items()}
else:
Expand All @@ -407,33 +407,33 @@ def _validate(
combined_loss[f"{key}_MAE_size"],
)
if is_test and test_result_save_path:
for idx, graph_i in enumerate(graphs):
for jj, graph_i in enumerate(graphs):
tmp = {
"mp_id": graph_i.mp_id,
"graph_id": graph_i.graph_id,
"energy": {
"ground_truth": targets["e"][idx].cpu().detach().tolist(),
"prediction": prediction["e"][idx].cpu().detach().tolist(),
"ground_truth": targets["e"][jj].cpu().detach().tolist(),
"prediction": prediction["e"][jj].cpu().detach().tolist(),
},
}
if "f" in self.targets:
tmp["force"] = {
"ground_truth": targets["f"][idx].cpu().detach().tolist(),
"prediction": prediction["f"][idx].cpu().detach().tolist(),
"ground_truth": targets["f"][jj].cpu().detach().tolist(),
"prediction": prediction["f"][jj].cpu().detach().tolist(),
}
if "s" in self.targets:
tmp["stress"] = {
"ground_truth": targets["s"][idx].cpu().detach().tolist(),
"prediction": prediction["s"][idx].cpu().detach().tolist(),
"ground_truth": targets["s"][jj].cpu().detach().tolist(),
"prediction": prediction["s"][jj].cpu().detach().tolist(),
}
if "m" in self.targets:
if targets["m"][idx] is not None:
m_ground_truth = targets["m"][idx].cpu().detach().tolist()
if targets["m"][jj] is not None:
m_ground_truth = targets["m"][jj].cpu().detach().tolist()
else:
m_ground_truth = None
tmp["mag"] = {
"ground_truth": m_ground_truth,
"prediction": prediction["m"][idx].cpu().detach().tolist(),
"prediction": prediction["m"][jj].cpu().detach().tolist(),
}
test_pred.append(tmp)

Expand All @@ -445,10 +445,10 @@ def _validate(
batch_time.update(time.perf_counter() - end)
end = time.perf_counter()

if (idx + 1) % self.print_freq == 0:
if (ii + 1) % self.print_freq == 0:
name = "Test" if is_test else "Val"
message = (
f"{name}: [{idx + 1}/{len(val_loader)}] | "
f"{name}: [{ii + 1}/{len(val_loader)}] | "
f"Time ({batch_time.avg:.3f}) | "
f"Loss {losses.val:.4f}({losses.avg:.4f}) | MAE "
)
Expand Down
8 changes: 4 additions & 4 deletions examples/basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@
"\n",
"# Relax the perturbed structure\n",
"result = relaxer.relax(structure, verbose=True)\n",
"print(f\"Relaxed structure:\\n\")\n",
"print(\"Relaxed structure:\\n\")\n",
"print(result[\"final_structure\"])"
]
},
Expand Down Expand Up @@ -1858,9 +1858,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python (tf)",
"display_name": "py311",
"language": "python",
"name": "myenv"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -1872,7 +1872,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.10"
"version": "3.11.7"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion examples/fine_tuning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,8 @@
}
],
"source": [
"from chgnet.trainer import Trainer\n",
"from chgnet.model import CHGNet\n",
"from chgnet.trainer import Trainer\n",
"\n",
"# Load pretrained CHGNet\n",
"chgnet = CHGNet.load()"
Expand Down
3 changes: 2 additions & 1 deletion tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def test_crystal_graph_converter_algorithm(algorithm: Literal["legacy", "fast"])
assert converter.algorithm == algorithm


def test_crystal_graph_converter_warns(_set_make_graph: None):
@pytest.mark.usefixtures("_set_make_graph")
def test_crystal_graph_converter_warns():
with pytest.warns(UserWarning, match="Unknown algorithm='foobar', using `legacy`"):
CrystalGraphConverter(
atom_graph_cutoff=5, bond_graph_cutoff=3, algorithm="foobar"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_structure_data(structure_data: StructureData) -> None:


def test_data_loader(structure_data: StructureData) -> None:
train_loader, val_loader, test_loader = get_train_val_test_loader(
train_loader, _val_loader, _test_loader = get_train_val_test_loader(
structure_data, batch_size=16, train_ratio=0.9, val_ratio=0.05
)
graphs, targets = next(iter(train_loader))
Expand Down
4 changes: 2 additions & 2 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

def test_trainer(tmp_path: Path) -> None:
chgnet = CHGNet.load()
train_loader, val_loader, test_loader = get_train_val_test_loader(
train_loader, val_loader, _test_loader = get_train_val_test_loader(
data, batch_size=16, train_ratio=0.9, val_ratio=0.05
)
trainer = Trainer(
Expand Down Expand Up @@ -68,7 +68,7 @@ def test_trainer_composition_model(tmp_path: Path) -> None:
chgnet = CHGNet.load()
for param in chgnet.composition_model.parameters():
assert param.requires_grad is False
train_loader, val_loader, test_loader = get_train_val_test_loader(
train_loader, val_loader, _test_loader = get_train_val_test_loader(
data, batch_size=16, train_ratio=0.9, val_ratio=0.05
)
trainer = Trainer(
Expand Down

0 comments on commit 6a5d96a

Please sign in to comment.