From 73d6096790bbb1e20d5f2e1e23119cca34e54483 Mon Sep 17 00:00:00 2001
From: Janosh Riebesell
Date: Tue, 10 Sep 2024 09:09:56 -0400
Subject: [PATCH] test.yml set fail-fast: false, fix some doc strings missing
 Returns/Raises, change parse_vasp_dir error type
 FileNotFoundError->NotADirectoryError

---
 .github/workflows/test.yml   |  2 +-
 .pre-commit-config.yaml      |  4 ++--
 chgnet/trainer/trainer.py    | 41 ++++++++++++++++++++++++++++++++----
 chgnet/utils/common_utils.py | 19 ++++++++++-------
 chgnet/utils/vasp_utils.py   | 12 ++++++++++-
 examples/make_graphs.py      | 18 ++++++++++++++--
 pyproject.toml               |  1 +
 site/make_docs.py            |  2 +-
 tests/test_dataset.py        | 26 +++++++++--------------
 9 files changed, 90 insertions(+), 35 deletions(-)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 1f5acd73..e2be1e1b 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -10,7 +10,7 @@ on:
 jobs:
   tests:
     strategy:
-      fail-fast: true
+      fail-fast: false
       matrix:
         os: [ubuntu-latest, macos-14, windows-latest]
         version:
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index fb6fd17d..ed544705 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg]
 
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.6.2
+    rev: v0.6.4
     hooks:
       - id: ruff
         args: [--fix]
@@ -48,7 +48,7 @@ repos:
           - svelte
 
   - repo: https://github.com/pre-commit/mirrors-eslint
-    rev: v9.9.0
+    rev: v9.10.0
    hooks:
      - id: eslint
        types: [file]
diff --git a/chgnet/trainer/trainer.py b/chgnet/trainer/trainer.py
index 056abf99..3d3cdf41 100644
--- a/chgnet/trainer/trainer.py
+++ b/chgnet/trainer/trainer.py
@@ -110,6 +110,12 @@ def __init__(
                 that are not included in the trainer_args.
                 Default = None
             **kwargs (dict): additional hyper-params for optimizer, scheduler, etc.
+
+        Raises:
+            NotImplementedError: If the optimizer or scheduler is not implemented
+            ImportError: If wandb_path is specified but wandb is not installed
+            ValueError: If wandb_path is specified but not in the format
+                'project/run_name'
         """
         # Store trainer args for reproducibility
         self.trainer_args = {
@@ -271,6 +277,9 @@ def train(
             wandb_log_freq ("epoch" | "batch"): Frequency of logging to wandb.
                 'epoch' logs once per epoch, 'batch' logs after every batch.
                 Default = "batch"
+
+        Raises:
+            ValueError: If model is not initialized
         """
         if self.model is None:
             raise ValueError("Model needs to be initialized")
@@ -579,7 +588,11 @@ def _validate(
         return {k: round(mae_error.avg, 6) for k, mae_error in mae_errors.items()}
 
     def get_best_model(self) -> CHGNet:
-        """Get best model recorded in the trainer."""
+        """Get best model recorded in the trainer.
+
+        Returns:
+            CHGNet: the model with lowest validation set energy error
+        """
         if self.best_model is None:
             raise RuntimeError("the model needs to be trained first")
         MAE = min(self.training_history["e"]["val"])  # noqa: N806
@@ -649,7 +662,14 @@ def save_checkpoint(self, epoch: int, mae_error: dict, save_dir: str) -> None:
 
     @classmethod
     def load(cls, path: str) -> Self:
-        """Load trainer state_dict."""
+        """Load trainer state_dict.
+
+        Args:
+            path (str): path to the saved model
+
+        Returns:
+            Trainer: the loaded trainer
+        """
         state = torch.load(path, map_location=torch.device("cpu"))
         model = CHGNet.from_dict(state["model"])
         print(f"Loaded model params = {sum(p.numel() for p in model.parameters()):,}")
@@ -664,8 +684,21 @@ def load(cls, path: str) -> Self:
         return trainer
 
     @staticmethod
-    def move_to(obj, device) -> Tensor | list[Tensor]:
-        """Move object to device."""
+    def move_to(
+        obj: Tensor | list[Tensor], device: torch.device
+    ) -> Tensor | list[Tensor]:
+        """Move object to device.
+
+        Args:
+            obj (Tensor | list[Tensor]): object(s) to move to device
+            device (torch.device): device to move object to
+
+        Raises:
+            TypeError: if obj is not a tensor or list of tensors
+
+        Returns:
+            Tensor | list[Tensor]: moved object(s)
+        """
         if torch.is_tensor(obj):
             return obj.to(device)
         if isinstance(obj, list):
diff --git a/chgnet/utils/common_utils.py b/chgnet/utils/common_utils.py
index 40872858..94079902 100644
--- a/chgnet/utils/common_utils.py
+++ b/chgnet/utils/common_utils.py
@@ -39,6 +39,9 @@ def cuda_devices_sorted_by_free_mem() -> list[int]:
     """List available CUDA devices sorted by increasing available memory.
 
     To get the device with the most free memory, use the last list item.
+
+    Returns:
+        list[int]: CUDA device numbers sorted by increasing free memory.
     """
     if not torch.cuda.is_available():
         return []
@@ -94,10 +97,10 @@ def mae(prediction: Tensor, target: Tensor) -> Tensor:
 
 
 def read_json(filepath: str) -> dict:
-    """Read the json file.
+    """Read the JSON file.
 
     Args:
-        filepath (str): file name of json to read.
+        filepath (str): file name of JSON to read.
 
     Returns:
         dict: data stored in filepath
@@ -107,14 +110,11 @@ def write_json(dct: dict, filepath: str) -> dict:
-    """Write the json file.
+    """Write the JSON file.
 
     Args:
         dct (dict): dictionary to write
-        filepath (str): file name of json to write.
-
-    Returns:
-        written dictionary
+        filepath (str): file name of JSON to write.
     """
 
     def handler(obj: object) -> int | object:
@@ -122,12 +122,15 @@ def handler(obj: object) -> int | object:
 
         Fixes TypeError: Object of type int64 is not JSON serializable
         reported in https://github.com/CederGroupHub/chgnet/issues/168.
+
+        Returns:
+            int | object: object for serialization
         """
         if isinstance(obj, np.integer):
             return int(obj)
         return obj
 
-    with open(filepath, "w") as file:
+    with open(filepath, mode="w") as file:
         json.dump(dct, file, default=handler)
diff --git a/chgnet/utils/vasp_utils.py b/chgnet/utils/vasp_utils.py
index e0f6b7b6..bfb54a07 100644
--- a/chgnet/utils/vasp_utils.py
+++ b/chgnet/utils/vasp_utils.py
@@ -30,9 +30,16 @@ def parse_vasp_dir(
             Exception to VASP calculation that did not achieve electronic convergence.
             Default = True
         save_path (str): path to save the parsed VASP labels
+
+    Raises:
+        NotADirectoryError: if the base_dir is not a directory
+
+    Returns:
+        dict: a dictionary of lists with keys for structure, uncorrected_total_energy,
+            energy_per_atom, force, magmom, stress.
     """
     if os.path.isdir(base_dir) is False:
-        raise FileNotFoundError(f"{base_dir=} is not a directory")
+        raise NotADirectoryError(f"{base_dir=} is not a directory")
 
     oszicar_path = zpath(f"{base_dir}/OSZICAR")
     vasprun_path = zpath(f"{base_dir}/vasprun.xml")
@@ -170,6 +177,9 @@ def solve_charge_by_mag(
                 (3.5, 4.2): 3,
                 (4.2, 5): 2
             ))
+
+    Returns:
+        Structure: pymatgen Structure with oxidation states assigned based on magmoms.
     """
     out_structure = structure.copy()
     out_structure.remove_oxidation_states()
diff --git a/examples/make_graphs.py b/examples/make_graphs.py
index 6093c4f6..8aacc2a5 100644
--- a/examples/make_graphs.py
+++ b/examples/make_graphs.py
@@ -58,8 +58,22 @@ def make_graphs(
     make_partition(labels, graph_dir, train_ratio, val_ratio)
 
 
-def make_one_graph(mp_id: str, graph_id: str, data, graph_dir) -> dict | bool:
-    """Convert a structure to a CrystalGraph and save it."""
+def make_one_graph(
+    mp_id: str, graph_id: str, data: StructureJsonData, graph_dir: str
+) -> dict | bool:
+    """Convert a structure to a CrystalGraph and save it.
+
+    Args:
+        mp_id (str): The material id.
+        graph_id (str): The graph id.
+        data (StructureJsonData): The dataset. Warning: Dicts are popped from the data,
+            i.e. modifying the data in place.
+        graph_dir (str): The directory to save the graphs.
+
+    Returns:
+        dict | bool: The label dictionary if the graph is saved successfully, False
+            otherwise.
+    """
     dct = data.data[mp_id].pop(graph_id)
     struct = Structure.from_dict(dct.pop("structure"))
     try:
diff --git a/pyproject.toml b/pyproject.toml
index cec86fd6..2c1f50e4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -102,6 +102,7 @@ docstring-code-format = true
     "ANN201",
     "D100",
     "D103",
+    "DOC201", # doc string missing Return section
     "FBT001",
     "FBT002",
     "INP001",
diff --git a/site/make_docs.py b/site/make_docs.py
index 79b12b26..a65b8063 100644
--- a/site/make_docs.py
+++ b/site/make_docs.py
@@ -49,5 +49,5 @@
     markdown = markdown.replace(
         "\n**Global Variables**\n---------------\n- **TYPE_CHECKING**\n\n", ""
     )
-    with open(path, "w") as file:
+    with open(path, mode="w") as file:
         file.write(markdown)
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index 475aba15..0c23d933 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -20,23 +20,17 @@ def structure_data() -> StructureData:
     """Create a graph with 3 nodes and 3 directed edges."""
     random.seed(42)
 
-    structures, energies, forces, stresses, magmoms, structure_ids = (
-        [],
-        [],
-        [],
-        [],
-        [],
-        [],
-    )
+    structures, energies, forces = [], [], []
+    stresses, magmoms, structure_ids = [], [], []
+
     for index in range(100):
-        struct = NaCl.copy()
-        struct.perturb(0.1)
-        structures.append(struct)
-        energies.append(np.random.random(1))
-        forces.append(np.random.random([2, 3]))
-        stresses.append(np.random.random([3, 3]))
-        magmoms.append(np.random.random([2, 1]))
-        structure_ids.append(index)
+        structures += [NaCl.copy().perturb(0.1)]
+        energies += [np.random.random(1)]
+        forces += [np.random.random([2, 3])]
+        stresses += [np.random.random([3, 3])]
+        magmoms += [np.random.random([2, 1])]
+        structure_ids += [index]
+
     return StructureData(
         structures=structures,
         energies=energies,