diff --git a/chgnet/trainer/trainer.py b/chgnet/trainer/trainer.py index 780654e..e606fbd 100644 --- a/chgnet/trainer/trainer.py +++ b/chgnet/trainer/trainer.py @@ -32,7 +32,7 @@ class Trainer: def __init__( self, - model: nn.Module | None = None, + model: CHGNet | None = None, targets: TrainTask = "ef", energy_loss_ratio: float = 1, force_loss_ratio: float = 1, @@ -382,11 +382,11 @@ def _validate( test_pred = [] end = time.perf_counter() - for idx, (graphs, targets) in enumerate(val_loader): + for ii, (graphs, targets) in enumerate(val_loader): if "f" in self.targets or "s" in self.targets: - for g in graphs: + for graph in graphs: requires_force = "f" in self.targets - g.atom_frac_coord.requires_grad = requires_force + graph.atom_frac_coord.requires_grad = requires_force graphs = [g.to(self.device) for g in graphs] targets = {k: self.move_to(v, self.device) for k, v in targets.items()} else: @@ -407,33 +407,33 @@ def _validate( combined_loss[f"{key}_MAE_size"], ) if is_test and test_result_save_path: - for idx, graph_i in enumerate(graphs): + for jj, graph_i in enumerate(graphs): tmp = { "mp_id": graph_i.mp_id, "graph_id": graph_i.graph_id, "energy": { - "ground_truth": targets["e"][idx].cpu().detach().tolist(), - "prediction": prediction["e"][idx].cpu().detach().tolist(), + "ground_truth": targets["e"][jj].cpu().detach().tolist(), + "prediction": prediction["e"][jj].cpu().detach().tolist(), }, } if "f" in self.targets: tmp["force"] = { - "ground_truth": targets["f"][idx].cpu().detach().tolist(), - "prediction": prediction["f"][idx].cpu().detach().tolist(), + "ground_truth": targets["f"][jj].cpu().detach().tolist(), + "prediction": prediction["f"][jj].cpu().detach().tolist(), } if "s" in self.targets: tmp["stress"] = { - "ground_truth": targets["s"][idx].cpu().detach().tolist(), - "prediction": prediction["s"][idx].cpu().detach().tolist(), + "ground_truth": targets["s"][jj].cpu().detach().tolist(), + "prediction": prediction["s"][jj].cpu().detach().tolist(), } if "m" in self.targets: - if targets["m"][idx] is not None: - m_ground_truth = targets["m"][idx].cpu().detach().tolist() + if targets["m"][jj] is not None: + m_ground_truth = targets["m"][jj].cpu().detach().tolist() else: m_ground_truth = None tmp["mag"] = { "ground_truth": m_ground_truth, - "prediction": prediction["m"][idx].cpu().detach().tolist(), + "prediction": prediction["m"][jj].cpu().detach().tolist(), } test_pred.append(tmp) @@ -445,10 +445,10 @@ def _validate( batch_time.update(time.perf_counter() - end) end = time.perf_counter() - if (idx + 1) % self.print_freq == 0: + if (ii + 1) % self.print_freq == 0: name = "Test" if is_test else "Val" message = ( - f"{name}: [{idx + 1}/{len(val_loader)}] | " + f"{name}: [{ii + 1}/{len(val_loader)}] | " f"Time ({batch_time.avg:.3f}) | " f"Loss {losses.val:.4f}({losses.avg:.4f}) | MAE " ) diff --git a/examples/basics.ipynb b/examples/basics.ipynb index 336b043..7b0101d 100644 --- a/examples/basics.ipynb +++ b/examples/basics.ipynb @@ -274,7 +274,7 @@ "\n", "# Relax the perturbed structure\n", "result = relaxer.relax(structure, verbose=True)\n", - "print(f\"Relaxed structure:\\n\")\n", + "print(\"Relaxed structure:\\n\")\n", "print(result[\"final_structure\"])" ] }, @@ -1858,9 +1858,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python (tf)", + "display_name": "py311", "language": "python", - "name": "myenv" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -1872,7 +1872,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.10" + "version": "3.11.7" } }, "nbformat": 4, diff --git a/examples/fine_tuning.ipynb b/examples/fine_tuning.ipynb index 04ea00f..c8b26eb 100644 --- a/examples/fine_tuning.ipynb +++ b/examples/fine_tuning.ipynb @@ -287,8 +287,8 @@ } ], "source": [ - "from chgnet.trainer import Trainer\n", "from chgnet.model import CHGNet\n", + "from chgnet.trainer import Trainer\n", "\n", "# Load pretrained CHGNet\n", "chgnet = CHGNet.load()" diff --git a/tests/test_converter.py b/tests/test_converter.py index eec38b1..1e8c757 100644 --- a/tests/test_converter.py +++ b/tests/test_converter.py @@ -47,7 +47,8 @@ def test_crystal_graph_converter_algorithm(algorithm: Literal["legacy", "fast"]) assert converter.algorithm == algorithm -def test_crystal_graph_converter_warns(_set_make_graph: None): +@pytest.mark.usefixtures("_set_make_graph") +def test_crystal_graph_converter_warns(): with pytest.warns(UserWarning, match="Unknown algorithm='foobar', using `legacy`"): CrystalGraphConverter( atom_graph_cutoff=5, bond_graph_cutoff=3, algorithm="foobar" diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 736a6be..9b8f192 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -56,7 +56,7 @@ def test_structure_data(structure_data: StructureData) -> None: def test_data_loader(structure_data: StructureData) -> None: - train_loader, val_loader, test_loader = get_train_val_test_loader( + train_loader, _val_loader, _test_loader = get_train_val_test_loader( structure_data, batch_size=16, train_ratio=0.9, val_ratio=0.05 ) graphs, targets = next(iter(train_loader)) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index fe28752..bcf44f6 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -38,7 +38,7 @@ def test_trainer(tmp_path: Path) -> None: chgnet = CHGNet.load() - train_loader, val_loader, test_loader = get_train_val_test_loader( + train_loader, val_loader, _test_loader = get_train_val_test_loader( data, batch_size=16, train_ratio=0.9, val_ratio=0.05 ) trainer = Trainer( @@ -68,7 +68,7 @@ def test_trainer_composition_model(tmp_path: Path) -> None: chgnet = CHGNet.load() for param in chgnet.composition_model.parameters(): assert param.requires_grad is False - train_loader, val_loader, test_loader = get_train_val_test_loader( + train_loader, val_loader, _test_loader = get_train_val_test_loader( data, batch_size=16, train_ratio=0.9, val_ratio=0.05 ) trainer = Trainer(